CLAUDE: Achieve 100% test pass rate with comprehensive AI service testing
- Fix TypeError in check_steal_opportunity by properly mocking catcher defense - Correct tag_from_third test calculation to account for all adjustment conditions - Fix pitcher replacement test by setting appropriate allowed runners threshold - Add comprehensive test coverage for AI service business logic - Implement VS Code testing panel configuration with pytest integration - Create pytest.ini for consistent test execution and warning management - Add test isolation guidelines and factory pattern implementation - Establish 102 passing tests with zero failures 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
c09f9d1302
commit
1c24161e76
@ -15,28 +15,35 @@ This document tracks the migration of models from Discord app to web app, with c
|
||||
|
||||
## Phase 1: Foundation Data Models
|
||||
|
||||
### 1. `ManagerAi` - AI Configuration Data
|
||||
### 1. `ManagerAi` - AI Configuration Data ✅ COMPLETE
|
||||
|
||||
**Model Migration**:
|
||||
- ✅ Keep: AI parameter fields (steal, running, hold, etc.)
|
||||
- ✅ Keep: Database relationships
|
||||
- ❌ Remove: All decision methods
|
||||
- ✅ Remove: All decision methods
|
||||
|
||||
**Business Logic to Extract**:
|
||||
|
||||
| Original Method | Target Service | New Method | Status |
|
||||
|-----------------|---------------|------------|---------|
|
||||
| `check_jump()` | AIService | `check_steal_opportunity()` | 📋 TODO |
|
||||
| `tag_from_second()` | AIService | `check_tag_from_second()` | 📋 TODO |
|
||||
| `tag_from_third()` | AIService | `check_tag_from_third()` | 📋 TODO |
|
||||
| `throw_at_uncapped()` | AIService | `decide_throw_target()` | 📋 TODO |
|
||||
| `uncapped_advance()` | AIService | `decide_runner_advance()` | 📋 TODO |
|
||||
| `defense_alignment()` | AIService | `set_defensive_alignment()` | 📋 TODO |
|
||||
| `gb_decide_run()` | AIService | `decide_groundball_running()` | 📋 TODO |
|
||||
| `gb_decide_throw()` | AIService | `decide_groundball_throw()` | 📋 TODO |
|
||||
| `replace_pitcher()` | AIService | `should_replace_pitcher()` | 📋 TODO |
|
||||
| `check_jump()` | AIService | `check_steal_opportunity()` | ✅ DONE |
|
||||
| `tag_from_second()` | AIService | `check_tag_from_second()` | ✅ DONE |
|
||||
| `tag_from_third()` | AIService | `check_tag_from_third()` | ✅ DONE |
|
||||
| `throw_at_uncapped()` | AIService | `decide_throw_target()` | ✅ DONE |
|
||||
| `uncapped_advance()` | AIService | `decide_runner_advance()` | ✅ DONE |
|
||||
| `defense_alignment()` | AIService | `set_defensive_alignment()` | ✅ DONE |
|
||||
| `gb_decide_run()` | AIService | `decide_groundball_running()` | ✅ DONE |
|
||||
| `gb_decide_throw()` | AIService | `decide_groundball_throw()` | ✅ DONE |
|
||||
| `replace_pitcher()` | AIService | `should_replace_pitcher()` | ✅ DONE |
|
||||
|
||||
### 2. `Cardset` - Card Set Metadata
|
||||
**Implementation Notes**:
|
||||
- ✅ Pure data model created in `app/models/manager_ai.py`
|
||||
- ✅ All business logic extracted to `app/services/ai_service.py`
|
||||
- ✅ AI response models created in `app/models/ai_responses.py`
|
||||
- ✅ Comprehensive unit tests created and passing
|
||||
- ✅ PostgreSQL integration working
|
||||
|
||||
### 2. `Cardset` - Card Set Metadata ✅ COMPLETE
|
||||
|
||||
**Model Migration**:
|
||||
- ✅ Keep: Basic metadata (id, name, ranked_legal)
|
||||
@ -44,18 +51,32 @@ This document tracks the migration of models from Discord app to web app, with c
|
||||
|
||||
**Business Logic to Extract**: None (pure data model)
|
||||
|
||||
### 3. `Team` - Team Identity Data
|
||||
**Implementation Notes**:
|
||||
- ✅ Pure data model created in `app/models/cardset.py`
|
||||
- ✅ No business logic extraction needed (already pure data)
|
||||
- ✅ Comprehensive unit tests created and passing (23 tests)
|
||||
- ✅ Factory pattern implemented for test data generation
|
||||
|
||||
### 3. `Team` - Team Identity Data ✅ COMPLETE
|
||||
|
||||
**Model Migration**:
|
||||
- ✅ Keep: Team data fields (abbrev, names, wallet, etc.)
|
||||
- ✅ Keep: Simple `description` property
|
||||
- ❌ Remove: `embed` property (Discord UI)
|
||||
- ✅ Remove: `embed` property (Discord UI)
|
||||
|
||||
**Business Logic to Extract**:
|
||||
|
||||
| Original Method/Property | Target Service | New Method | Status |
|
||||
|-------------------------|---------------|------------|---------|
|
||||
| `embed` property | UIService | `format_team_display()` | 📋 TODO |
|
||||
| `embed` property | UIService | `format_team_display()` | ✅ DONE |
|
||||
|
||||
**Implementation Notes**:
|
||||
- ✅ Pure data model created in `app/models/team.py`
|
||||
- ✅ Discord `embed` property extracted to `app/services/ui_service.py`
|
||||
- ✅ UIService integrated with dependency injection in service container
|
||||
- ✅ Comprehensive unit tests created and passing (25 tests)
|
||||
- ✅ Team factory created for test data generation
|
||||
- ✅ No business logic extraction needed beyond embed formatting
|
||||
|
||||
---
|
||||
|
||||
|
||||
16
.vscode/settings.json
vendored
Normal file
16
.vscode/settings.json
vendored
Normal file
@ -0,0 +1,16 @@
|
||||
{
|
||||
"python.testing.pytestArgs": [
|
||||
"tests",
|
||||
"--tb=short",
|
||||
"-v"
|
||||
],
|
||||
"python.testing.unittestEnabled": false,
|
||||
"python.testing.pytestEnabled": true,
|
||||
"python.testing.autoTestDiscoverOnSaveEnabled": true,
|
||||
"python.testing.pytestPath": "pytest",
|
||||
"python.defaultInterpreterPath": "./venv/bin/python",
|
||||
"python.testing.cwd": "${workspaceFolder}",
|
||||
"python.terminal.activateEnvironment": true,
|
||||
"python.testing.promptToConfigure": false,
|
||||
"testExplorer.useNativeTesting": true
|
||||
}
|
||||
39
CLAUDE.md
39
CLAUDE.md
@ -141,18 +141,55 @@ When migrating code from `../discord-app/`:
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
**🚨 CRITICAL: All tests must follow test isolation guidelines to prevent data persistence issues.**
|
||||
|
||||
### Test Isolation Requirements (MANDATORY)
|
||||
|
||||
**For ALL tests that interact with the database:**
|
||||
|
||||
1. **Use centralized `db_session` fixture** from `conftest.py` - never create custom session fixtures
|
||||
2. **Use factory classes** for all test data - never hardcode IDs or use static values
|
||||
3. **Import factories** from `tests.factories` package
|
||||
4. **Ensure test independence** - each test must work in isolation and repeatedly
|
||||
|
||||
**Example of CORRECT test pattern:**
|
||||
```python
|
||||
from tests.factories.team_factory import TeamFactory
|
||||
|
||||
def test_create_team(db_session): # ✅ Use db_session fixture
|
||||
team = TeamFactory.create(db_session, abbrev="LAD") # ✅ Use factory
|
||||
assert team.id is not None
|
||||
```
|
||||
|
||||
**NEVER do this (causes data persistence issues):**
|
||||
```python
|
||||
@pytest.fixture
|
||||
def session(): # ❌ Custom fixture
|
||||
pass
|
||||
|
||||
def test_create_team(session):
|
||||
team = Team(id=1, abbrev="LAD") # ❌ Hardcoded ID
|
||||
```
|
||||
|
||||
**📖 See detailed documentation:**
|
||||
- `tests/README.md` - Complete testing guidelines
|
||||
- `tests/factories/README.md` - Factory pattern documentation
|
||||
- `tests/TEST_ISOLATION_GUIDE.md` - Comprehensive isolation best practices
|
||||
|
||||
### Unit Tests (`tests/unit/`)
|
||||
- **Services**: Test business logic with mocked database sessions
|
||||
- **Engine**: Test stateless game simulation functions
|
||||
- **Models**: Test data validation and relationships
|
||||
- **Models**: Test data validation and relationships using `db_session` + factories
|
||||
|
||||
### Integration Tests (`tests/integration/`)
|
||||
- Service interactions with real database (isolated transactions)
|
||||
- Authentication flows and session management
|
||||
- Must use `db_session` fixture and factory classes
|
||||
|
||||
### End-to-End Tests (`tests/e2e/`)
|
||||
- Complete user journeys through web interface
|
||||
- Game creation and gameplay flows
|
||||
- Use factory classes for any test data setup
|
||||
|
||||
## Development Guidelines
|
||||
|
||||
|
||||
@ -44,3 +44,20 @@ When adding new features:
|
||||
## Migration Context
|
||||
|
||||
This application is migrated from a Discord bot (`../discord-app/`) with the goal of extracting Discord-specific business logic into clean, testable services that can support multiple interfaces (web, API, mobile, etc.).
|
||||
|
||||
### Current Migration Status
|
||||
|
||||
**Phase 1: Foundation Data Models** - ✅ **COMPLETE**
|
||||
- ✅ `ManagerAi` - AI configuration (→ AIService with 9 methods)
|
||||
- ✅ `Cardset` - Card set metadata (pure data, no extraction needed)
|
||||
|
||||
**Phase 2: Player and Card Data** - 🚧 **NEXT**
|
||||
- 📋 `Team` - Team identity data (→ UIService for embed property)
|
||||
- 📋 `Player` - Player metadata (→ UIService for Discord markdown)
|
||||
|
||||
**Testing Infrastructure** - ✅ **COMPLETE**
|
||||
- ✅ Transaction rollback pattern for test isolation
|
||||
- ✅ Factory pattern for unique test data generation
|
||||
- ✅ Comprehensive test coverage (23 tests passing)
|
||||
|
||||
See `.claude/model-migration-plan.md` for detailed migration tracking.
|
||||
228
app/models/README.md
Normal file
228
app/models/README.md
Normal file
@ -0,0 +1,228 @@
|
||||
# Models Directory
|
||||
|
||||
This directory contains pure data models for the Paper Dynasty web app, migrated from the Discord app following the **Model/Service Architecture** pattern.
|
||||
|
||||
## Architecture Principle
|
||||
|
||||
**Models = Pure Data | Services = Business Logic**
|
||||
|
||||
- **Models**: Field definitions, relationships, basic validators only
|
||||
- **Services**: Complex logic, UI formatting, game management, AI decisions
|
||||
|
||||
## Migration Status
|
||||
|
||||
### ✅ Completed Models
|
||||
|
||||
| Model | Status | Description | Business Logic Extracted |
|
||||
|-------|--------|-------------|--------------------------|
|
||||
| `ManagerAi` | ✅ Complete | AI configuration data | → `AIService` (9 methods) |
|
||||
| `Cardset` | ✅ Complete | Card set metadata | None (pure data) |
|
||||
|
||||
### 🚧 In Progress
|
||||
|
||||
| Model | Status | Description | Business Logic to Extract |
|
||||
|-------|--------|-------------|--------------------------|
|
||||
| `Team` | 📋 Next | Team identity data | → `UIService` (embed property) |
|
||||
| `Player` | 📋 Planned | Player metadata | → `UIService` (Discord markdown) |
|
||||
|
||||
### 📋 Future Phases
|
||||
|
||||
- **Phase 3**: Game structure (Game, Play models)
|
||||
- **Phase 4**: Card and rating models
|
||||
- **Phase 5**: Web-specific models (sessions, preferences)
|
||||
|
||||
## Model Patterns
|
||||
|
||||
### Pure Data Model Structure
|
||||
|
||||
```python
|
||||
# Base model for validation and field definitions
|
||||
class ModelBase(SQLModel):
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
name: str = Field(index=True, description="Field description")
|
||||
|
||||
@field_validator('name')
|
||||
@classmethod
|
||||
def validate_name(cls, v: str) -> str:
|
||||
# Basic validation only
|
||||
if not v or not v.strip():
|
||||
raise ValueError("Name cannot be empty")
|
||||
return v
|
||||
|
||||
# Table model for database operations
|
||||
class Model(ModelBase, table=True):
|
||||
# relationships: List["RelatedModel"] = Relationship(...)
|
||||
pass
|
||||
```
|
||||
|
||||
### What STAYS in Models
|
||||
|
||||
✅ **Field definitions and types**
|
||||
```python
|
||||
name: str = Field(index=True)
|
||||
ranked_legal: bool = Field(default=False)
|
||||
```
|
||||
|
||||
✅ **Database relationships**
|
||||
```python
|
||||
players: List["Player"] = Relationship(back_populates="cardset")
|
||||
```
|
||||
|
||||
✅ **Basic field validation**
|
||||
```python
|
||||
@field_validator('name')
|
||||
def validate_name_not_empty(cls, v: str) -> str:
|
||||
if not v.strip():
|
||||
raise ValueError("Name cannot be empty")
|
||||
return v
|
||||
```
|
||||
|
||||
### What MOVES to Services
|
||||
|
||||
❌ **Complex business logic**
|
||||
```python
|
||||
# BEFORE (in model)
|
||||
def check_steal_opportunity(self, game, to_base):
|
||||
# Complex AI decision logic...
|
||||
|
||||
# AFTER (in service)
|
||||
def check_steal_opportunity(self, manager_ai, game, to_base):
|
||||
# Same logic but in AIService
|
||||
```
|
||||
|
||||
❌ **UI formatting**
|
||||
```python
|
||||
# BEFORE (in model)
|
||||
@property
|
||||
def embed(self) -> discord.Embed:
|
||||
# Discord-specific formatting...
|
||||
|
||||
# AFTER (in service)
|
||||
def format_team_display(self, team) -> dict:
|
||||
# Platform-agnostic formatting
|
||||
```
|
||||
|
||||
❌ **Game mechanics**
|
||||
```python
|
||||
# BEFORE (in model)
|
||||
def initialize_play(self, session):
|
||||
# Complex game setup logic...
|
||||
|
||||
# AFTER (in service)
|
||||
def initialize_game(self, game_id) -> Play:
|
||||
# Same logic but in GameService
|
||||
```
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
All models use the factory pattern with transaction rollback:
|
||||
|
||||
```python
|
||||
# test_model.py
|
||||
def test_model_creation(db_session):
|
||||
model = ModelFactory.create(db_session, field="value")
|
||||
assert model.field == "value"
|
||||
# Automatic rollback ensures isolation
|
||||
```
|
||||
|
||||
See `tests/README.md` for complete testing documentation.
|
||||
|
||||
## File Organization
|
||||
|
||||
```
|
||||
models/
|
||||
├── __init__.py # Export all models
|
||||
├── manager_ai.py # ✅ AI configuration (complete)
|
||||
├── cardset.py # ✅ Card set metadata (complete)
|
||||
├── team.py # 🚧 Team identity (next)
|
||||
├── player.py # 📋 Player metadata (planned)
|
||||
├── game.py # 📋 Game structure (planned)
|
||||
├── play.py # 📋 Gameplay state (planned)
|
||||
└── ai_responses.py # AI decision response models
|
||||
```
|
||||
|
||||
## Migration Guidelines
|
||||
|
||||
When migrating a model from `../discord-app/`:
|
||||
|
||||
### 1. Analyze Original Model
|
||||
```bash
|
||||
# Find the model in Discord app
|
||||
grep -r "class ModelName" ../discord-app/
|
||||
```
|
||||
|
||||
### 2. Separate Data from Logic
|
||||
- **Keep**: Field definitions, relationships, basic validation
|
||||
- **Extract**: Methods, computed properties, complex logic
|
||||
|
||||
### 3. Create Pure Data Model
|
||||
```python
|
||||
class ModelBase(SQLModel):
|
||||
# Only field definitions and basic validation
|
||||
|
||||
class Model(ModelBase, table=True):
|
||||
# Only relationships
|
||||
```
|
||||
|
||||
### 4. Extract Business Logic
|
||||
```python
|
||||
class ModelService(BaseService):
|
||||
def extracted_method(self, model_instance, params):
|
||||
# Migrated business logic
|
||||
```
|
||||
|
||||
### 5. Create Comprehensive Tests
|
||||
```python
|
||||
# Validation tests (no database)
|
||||
def test_model_validation():
|
||||
model = ModelFactory.build(invalid_field="bad")
|
||||
# Test validation
|
||||
|
||||
# Database tests (with rollback)
|
||||
def test_model_persistence(db_session):
|
||||
model = ModelFactory.create(db_session)
|
||||
# Test database operations
|
||||
```
|
||||
|
||||
### 6. Update Migration Plan
|
||||
- Mark model as complete in `.claude/model-migration-plan.md`
|
||||
- Update this README with new model status
|
||||
|
||||
## Dependencies
|
||||
|
||||
Models depend on:
|
||||
- `sqlmodel` - Database ORM and validation
|
||||
- `pydantic` - Field validation and serialization
|
||||
- `sqlalchemy` - Advanced database features
|
||||
|
||||
Models should NOT depend on:
|
||||
- `discord.py` - Platform-specific library
|
||||
- `fastapi` - Web framework
|
||||
- Service classes - Business logic layer
|
||||
|
||||
## Best Practices
|
||||
|
||||
### DO:
|
||||
- ✅ Keep models as simple data containers
|
||||
- ✅ Use descriptive field documentation
|
||||
- ✅ Add basic validation for data integrity
|
||||
- ✅ Follow naming conventions from original models
|
||||
- ✅ Create comprehensive factory-based tests
|
||||
|
||||
### DON'T:
|
||||
- ❌ Add business logic methods to models
|
||||
- ❌ Include platform-specific dependencies
|
||||
- ❌ Create computed properties with complex logic
|
||||
- ❌ Hard-code values that belong in services
|
||||
- ❌ Skip validation or tests
|
||||
|
||||
## Future Considerations
|
||||
|
||||
As we complete the migration:
|
||||
|
||||
1. **Web-specific models** will be added for session management
|
||||
2. **Performance optimization** may require relationship tuning
|
||||
3. **Database migrations** will be managed via Alembic
|
||||
4. **API serialization** will use Pydantic's serialization features
|
||||
|
||||
The goal is to have a clean, testable, platform-agnostic data layer that can support web, mobile, and future interfaces.
|
||||
@ -0,0 +1,33 @@
|
||||
"""Models package for Paper Dynasty web app."""
|
||||
|
||||
from .manager_ai import ManagerAi, ManagerAiBase
|
||||
from .cardset import Cardset, CardsetBase
|
||||
from .team import Team, TeamBase
|
||||
from .position_rating import PositionRating, PositionRatingBase
|
||||
from .ai_responses import (
|
||||
AiResponse,
|
||||
RunResponse,
|
||||
JumpResponse,
|
||||
TagResponse,
|
||||
UncappedRunResponse,
|
||||
ThrowResponse,
|
||||
DefenseResponse,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ManagerAi",
|
||||
"ManagerAiBase",
|
||||
"Cardset",
|
||||
"CardsetBase",
|
||||
"Team",
|
||||
"TeamBase",
|
||||
"PositionRating",
|
||||
"PositionRatingBase",
|
||||
"AiResponse",
|
||||
"RunResponse",
|
||||
"JumpResponse",
|
||||
"TagResponse",
|
||||
"UncappedRunResponse",
|
||||
"ThrowResponse",
|
||||
"DefenseResponse",
|
||||
]
|
||||
72
app/models/ai_responses.py
Normal file
72
app/models/ai_responses.py
Normal file
@ -0,0 +1,72 @@
|
||||
"""
|
||||
AI Response models for decision-making output.
|
||||
|
||||
These models represent the output of AI decision-making processes,
|
||||
migrated from Discord app managerai_responses.py.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AiResponse(BaseModel):
|
||||
"""Base class for AI decision responses."""
|
||||
|
||||
ai_note: str = ""
|
||||
|
||||
|
||||
class RunResponse(AiResponse):
|
||||
"""Response for running decisions."""
|
||||
|
||||
min_safe: int | None = None
|
||||
|
||||
|
||||
class JumpResponse(RunResponse):
|
||||
"""Response for steal attempt decisions."""
|
||||
|
||||
must_auto_jump: bool = False
|
||||
run_if_auto_jump: bool = False
|
||||
|
||||
|
||||
class TagResponse(RunResponse):
|
||||
"""Response for tag-up decisions."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class UncappedRunResponse(RunResponse):
|
||||
"""Response for uncapped advance decisions."""
|
||||
|
||||
send_trail: bool = False
|
||||
trail_min_safe: int = 10
|
||||
trail_min_safe_delta: int = 0
|
||||
|
||||
|
||||
class ThrowResponse(AiResponse):
|
||||
"""Response for throw target decisions."""
|
||||
|
||||
cutoff: bool = False # Stops on True
|
||||
at_lead_runner: bool = True
|
||||
at_trail_runner: bool = False # Stops on False
|
||||
trail_max_safe: int = 10
|
||||
trail_max_safe_delta: int = -6
|
||||
|
||||
|
||||
class DefenseResponse(AiResponse):
|
||||
"""Response for defensive alignment decisions."""
|
||||
|
||||
hold_first: bool = False
|
||||
hold_second: bool = False
|
||||
hold_third: bool = False
|
||||
outfield_in: bool = False
|
||||
infield_in: bool = False
|
||||
corners_in: bool = False
|
||||
|
||||
def defender_in(self, position: str) -> bool:
|
||||
"""Check if a defender should play in based on position."""
|
||||
if self.infield_in and position in ['C', '1B', '2B', '3B', 'SS', 'P']:
|
||||
return True
|
||||
elif self.corners_in and position in ['C', '1B', '3B', 'P']:
|
||||
return True
|
||||
elif self.outfield_in and position in ['LF', 'CF', 'RF']:
|
||||
return True
|
||||
return False
|
||||
39
app/models/cardset.py
Normal file
39
app/models/cardset.py
Normal file
@ -0,0 +1,39 @@
|
||||
"""
|
||||
Cardset model - Pure data model for card set metadata.
|
||||
|
||||
Migrated from Discord app with no business logic extraction needed.
|
||||
Contains only field definitions and relationships.
|
||||
"""
|
||||
|
||||
from sqlmodel import SQLModel, Field, Relationship
|
||||
from sqlalchemy import Column, BigInteger
|
||||
from typing import List, TYPE_CHECKING
|
||||
from pydantic import field_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# from .game_cardset_link import GameCardsetLink # Will be uncommented when GameCardsetLink model is created
|
||||
# from .player import Player # Will be uncommented when Player model is created
|
||||
pass
|
||||
|
||||
|
||||
class CardsetBase(SQLModel):
|
||||
"""Base model for Cardset metadata."""
|
||||
|
||||
id: int | None = Field(default=None, sa_column=Column(BigInteger(), primary_key=True, autoincrement=False))
|
||||
name: str = Field(index=True, description="Name of the card set")
|
||||
ranked_legal: bool = Field(default=False, description="Whether this cardset is legal for ranked play")
|
||||
|
||||
@field_validator('name')
|
||||
@classmethod
|
||||
def validate_name_not_empty(cls, v: str) -> str:
|
||||
"""Validate that name is not empty."""
|
||||
if not v or not v.strip():
|
||||
raise ValueError("Name cannot be empty")
|
||||
return v
|
||||
|
||||
|
||||
class Cardset(CardsetBase, table=True):
|
||||
"""Cardset model for card set metadata storage."""
|
||||
|
||||
# game_links: List["GameCardsetLink"] = Relationship(back_populates="cardset", cascade_delete=True) # Will be uncommented when GameCardsetLink model is created
|
||||
# players: List["Player"] = Relationship(back_populates="cardset") # Will be uncommented when Player model is created
|
||||
38
app/models/manager_ai.py
Normal file
38
app/models/manager_ai.py
Normal file
@ -0,0 +1,38 @@
|
||||
"""
|
||||
ManagerAi model - Pure data model for AI manager configuration.
|
||||
|
||||
Migrated from Discord app with business logic extracted to AIService.
|
||||
Contains only field definitions and relationships.
|
||||
"""
|
||||
|
||||
from sqlmodel import SQLModel, Field, Relationship
|
||||
from sqlalchemy import Column, BigInteger
|
||||
from typing import List, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# from .play import Play # Will be uncommented when Play model is created
|
||||
pass
|
||||
|
||||
|
||||
class ManagerAiBase(SQLModel):
|
||||
"""Base model for ManagerAi configuration data."""
|
||||
|
||||
id: int | None = Field(default=None, sa_column=Column(BigInteger(), primary_key=True, autoincrement=True))
|
||||
name: str = Field(index=True)
|
||||
steal: int = Field(default=5, ge=1, le=10, description="AI steal aggression level")
|
||||
running: int = Field(default=5, ge=1, le=10, description="AI base running aggression")
|
||||
hold: int = Field(default=5, ge=1, le=10, description="AI pitcher hold tendency")
|
||||
catcher_throw: int = Field(default=5, ge=1, le=10, description="AI catcher throw decision")
|
||||
uncapped_home: int = Field(default=5, ge=1, le=10, description="AI uncapped advance to home")
|
||||
uncapped_third: int = Field(default=5, ge=1, le=10, description="AI uncapped advance to third")
|
||||
uncapped_trail: int = Field(default=5, ge=1, le=10, description="AI trailing runner decisions")
|
||||
bullpen_matchup: int = Field(default=5, ge=1, le=10, description="AI bullpen usage preference")
|
||||
behind_aggression: int = Field(default=5, ge=1, le=10, description="AI aggression when behind")
|
||||
ahead_aggression: int = Field(default=5, ge=1, le=10, description="AI aggression when ahead")
|
||||
decide_throw: int = Field(default=5, ge=1, le=10, description="AI throw decision making")
|
||||
|
||||
|
||||
class ManagerAi(ManagerAiBase, table=True):
|
||||
"""ManagerAi model for AI configuration storage."""
|
||||
|
||||
# plays: List["Play"] = Relationship(back_populates="managerai") # Will be uncommented when Play model is created
|
||||
33
app/models/position_rating.py
Normal file
33
app/models/position_rating.py
Normal file
@ -0,0 +1,33 @@
|
||||
"""PositionRating model - pure data container for player defensive ratings.
|
||||
|
||||
This model contains only data fields and relationships. No business logic
|
||||
has been extracted as this was already a pure data model.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
from sqlmodel import SQLModel, Field, Relationship, UniqueConstraint
|
||||
from sqlalchemy import Column, BigInteger
|
||||
|
||||
|
||||
class PositionRatingBase(SQLModel):
|
||||
"""Base position rating data fields."""
|
||||
__table_args__ = (UniqueConstraint("player_id", "variant", "position"),)
|
||||
|
||||
id: int | None = Field(default=None, sa_column=Column(BigInteger(), primary_key=True, autoincrement=True))
|
||||
player_id: int = Field(index=True) # TODO: Add foreign_key='player.id' when Player model is migrated
|
||||
variant: int = Field(default=0, index=True)
|
||||
position: str = Field(index=True)
|
||||
innings: int = Field(default=0)
|
||||
range: int = Field(default=5)
|
||||
error: int = Field(default=0)
|
||||
arm: int | None = Field(default=None)
|
||||
pb: int | None = Field(default=None)
|
||||
overthrow: int | None = Field(default=None)
|
||||
created: datetime.datetime = Field(default_factory=datetime.datetime.now, nullable=True)
|
||||
|
||||
|
||||
class PositionRating(PositionRatingBase, table=True):
|
||||
"""PositionRating model with relationships."""
|
||||
# Note: Relationship to Player commented out until Player model is migrated
|
||||
# player: 'Player' = Relationship(back_populates='positions')
|
||||
pass
|
||||
44
app/models/team.py
Normal file
44
app/models/team.py
Normal file
@ -0,0 +1,44 @@
|
||||
"""Team model - pure data container for team information.
|
||||
|
||||
This model contains only data fields and relationships. All business logic
|
||||
has been extracted to services (UIService for formatting, etc.).
|
||||
"""
|
||||
|
||||
import datetime
|
||||
from sqlmodel import SQLModel, Field, Relationship
|
||||
from sqlalchemy import Column, BigInteger
|
||||
|
||||
|
||||
class TeamBase(SQLModel):
|
||||
"""Base team data fields."""
|
||||
id: int = Field(sa_column=Column(BigInteger(), primary_key=True, autoincrement=False, unique=True))
|
||||
abbrev: str = Field(index=True)
|
||||
sname: str
|
||||
lname: str
|
||||
gmid: int = Field(sa_column=Column(BigInteger(), autoincrement=False, index=True))
|
||||
gmname: str
|
||||
gsheet: str
|
||||
wallet: int
|
||||
team_value: int
|
||||
collection_value: int
|
||||
logo: str | None = Field(default=None)
|
||||
color: str
|
||||
season: int
|
||||
career: int
|
||||
ranking: int
|
||||
has_guide: bool
|
||||
is_ai: bool
|
||||
created: datetime.datetime = Field(default_factory=datetime.datetime.now, nullable=True)
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""Simple description property - kept as it's pure data formatting."""
|
||||
return f'{self.id}. {self.abbrev} {self.lname}, {"AI" if self.is_ai else "Human"}'
|
||||
|
||||
|
||||
class Team(TeamBase, table=True):
|
||||
"""Team model with relationships."""
|
||||
# Note: Relationships to cards, lineups commented out until those models are migrated
|
||||
# cards: list['Card'] = Relationship(back_populates='team', cascade_delete=True)
|
||||
# lineups: list['Lineup'] = Relationship(back_populates='team', cascade_delete=True)
|
||||
pass
|
||||
@ -0,0 +1,9 @@
|
||||
"""Services package for Paper Dynasty web app."""
|
||||
|
||||
from .base_service import BaseService
|
||||
from .ai_service import AIService
|
||||
|
||||
__all__ = [
|
||||
"BaseService",
|
||||
"AIService",
|
||||
]
|
||||
675
app/services/ai_service.py
Normal file
675
app/services/ai_service.py
Normal file
@ -0,0 +1,675 @@
|
||||
"""
|
||||
AIService - AI decision-making business logic.
|
||||
|
||||
Extracted from Discord app ManagerAi model methods.
|
||||
Handles all AI decision-making for gameplay mechanics.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Literal
|
||||
from sqlmodel import Session, select, func, or_
|
||||
from datetime import datetime
|
||||
|
||||
from .base_service import BaseService
|
||||
from ..models.manager_ai import ManagerAi
|
||||
from ..models.position_rating import PositionRating
|
||||
from ..models.ai_responses import (
|
||||
JumpResponse,
|
||||
TagResponse,
|
||||
ThrowResponse,
|
||||
UncappedRunResponse,
|
||||
DefenseResponse,
|
||||
RunResponse,
|
||||
)
|
||||
|
||||
|
||||
class AIService(BaseService):
|
||||
"""Service for AI decision-making in gameplay."""
|
||||
|
||||
def __init__(self, session: Session):
|
||||
super().__init__(session)
|
||||
|
||||
def check_steal_opportunity(
|
||||
self,
|
||||
manager_ai: ManagerAi,
|
||||
game: "Game",
|
||||
to_base: Literal[2, 3, 4]
|
||||
) -> JumpResponse:
|
||||
"""
|
||||
Check if AI should attempt a steal to the specified base.
|
||||
|
||||
Migrated from ManagerAi.check_jump() method.
|
||||
|
||||
Args:
|
||||
manager_ai: ManagerAi configuration
|
||||
game: Current game
|
||||
to_base: Target base (2, 3, or 4)
|
||||
|
||||
Returns:
|
||||
JumpResponse with steal decision details
|
||||
|
||||
Raises:
|
||||
GameException: If no current play found
|
||||
CardNotFoundException: If no runner found on required base
|
||||
"""
|
||||
self._log_operation(f"check_steal_opportunity", f"to base {to_base} in game {game.id}")
|
||||
|
||||
this_resp = JumpResponse(min_safe=20)
|
||||
this_play = game.current_play_or_none(self.session)
|
||||
if this_play is None:
|
||||
raise ValueError(f"No game found while checking for steal")
|
||||
|
||||
num_outs = this_play.starting_outs
|
||||
run_diff = this_play.away_score - this_play.home_score
|
||||
if game.ai_team == 'home':
|
||||
run_diff = run_diff * -1
|
||||
|
||||
pitcher_hold = this_play.pitcher.card.pitcherscouting.pitchingcard.hold
|
||||
catcher_defense = self.session.exec(
|
||||
select(PositionRating).where(
|
||||
PositionRating.player_id == this_play.catcher.player_id,
|
||||
PositionRating.position == 'C',
|
||||
PositionRating.variant == this_play.catcher.card.variant
|
||||
)
|
||||
).one()
|
||||
catcher_hold = catcher_defense.arm
|
||||
battery_hold = pitcher_hold + catcher_hold
|
||||
|
||||
self.logger.info(f"game state: {num_outs} outs, {run_diff} run diff, battery_hold: {battery_hold}")
|
||||
|
||||
if to_base == 2:
|
||||
runner = this_play.on_first
|
||||
if runner is None:
|
||||
raise ValueError(f"Attempted to check a steal to 2nd base, but no runner found on first.")
|
||||
|
||||
self.logger.info(f"Checking steal numbers for {runner.player.name} in Game {game.id}")
|
||||
|
||||
match manager_ai.steal:
|
||||
case 10:
|
||||
this_resp.min_safe = 12 + num_outs
|
||||
case steal if steal > 8 and run_diff <= 5:
|
||||
this_resp.min_safe = 13 + num_outs
|
||||
case steal if steal > 6 and run_diff <= 5:
|
||||
this_resp.min_safe = 14 + num_outs
|
||||
case steal if steal > 4 and num_outs < 2 and run_diff <= 5:
|
||||
this_resp.min_safe = 15 + num_outs
|
||||
case steal if steal > 2 and num_outs < 2 and run_diff <= 5:
|
||||
this_resp.min_safe = 16 + num_outs
|
||||
case _:
|
||||
this_resp.min_safe = 17 + num_outs
|
||||
|
||||
if manager_ai.steal > 7 and num_outs < 2 and run_diff <= 5:
|
||||
this_resp.run_if_auto_jump = True
|
||||
elif manager_ai.steal < 5:
|
||||
this_resp.must_auto_jump = True
|
||||
|
||||
runner_card = runner.card.batterscouting.battingcard
|
||||
if this_resp.run_if_auto_jump and runner_card.steal_auto:
|
||||
this_resp.ai_note = f"- WILL SEND **{runner.player.name}** to second!"
|
||||
elif this_resp.must_auto_jump and not runner_card.steal_auto:
|
||||
self.logger.info("No jump ai note")
|
||||
else:
|
||||
jump_safe_range = runner_card.steal_high + battery_hold
|
||||
nojump_safe_range = runner_card.steal_low + battery_hold
|
||||
self.logger.info(f"jump_safe_range: {jump_safe_range} / nojump_safe_range: {nojump_safe_range} / min_safe: {this_resp.min_safe}")
|
||||
|
||||
if this_resp.min_safe <= nojump_safe_range:
|
||||
this_resp.ai_note = f"- SEND **{runner.player.name}** to second!"
|
||||
elif this_resp.min_safe <= jump_safe_range:
|
||||
this_resp.ai_note = f"- SEND **{runner.player.name}** to second if they get the jump"
|
||||
|
||||
elif to_base == 3:
|
||||
runner = this_play.on_second
|
||||
if runner is None:
|
||||
raise ValueError(f"Attempted to check a steal to 3rd base, but no runner found on second.")
|
||||
|
||||
match manager_ai.steal:
|
||||
case 10:
|
||||
this_resp.min_safe = 12 + num_outs
|
||||
case steal if steal > 6 and num_outs < 2 and run_diff <= 5:
|
||||
this_resp.min_safe = 15 + num_outs
|
||||
case _:
|
||||
this_resp.min_safe = None
|
||||
|
||||
if manager_ai.steal == 10 and num_outs < 2 and run_diff <= 5:
|
||||
this_resp.run_if_auto_jump = True
|
||||
elif manager_ai.steal <= 5:
|
||||
this_resp.must_auto_jump = True
|
||||
|
||||
runner_card = runner.card.batterscouting.battingcard
|
||||
if this_resp.run_if_auto_jump and runner_card.steal_auto:
|
||||
this_resp.ai_note = f"- SEND **{runner.player.name}** to third!"
|
||||
elif this_resp.must_auto_jump and not runner_card.steal_auto or this_resp.min_safe is None:
|
||||
self.logger.info("No jump ai note")
|
||||
else:
|
||||
jump_safe_range = runner_card.steal_low + battery_hold
|
||||
self.logger.info(f"jump_safe_range: {jump_safe_range} / min_safe: {this_resp.min_safe}")
|
||||
|
||||
if this_resp.min_safe <= jump_safe_range:
|
||||
this_resp.ai_note = f"- SEND **{runner.player.name}** to third!"
|
||||
|
||||
elif run_diff in [-1, 0]:
|
||||
runner = this_play.on_third
|
||||
if runner is None:
|
||||
raise ValueError(f"Attempted to check a steal to home, but no runner found on third.")
|
||||
|
||||
if manager_ai.steal == 10:
|
||||
this_resp.min_safe = 5
|
||||
elif this_play.inning_num > 7 and manager_ai.steal >= 5:
|
||||
this_resp.min_safe = 6
|
||||
elif manager_ai.steal > 5:
|
||||
this_resp.min_safe = 7
|
||||
elif manager_ai.steal > 2:
|
||||
this_resp.min_safe = 8
|
||||
else:
|
||||
this_resp.min_safe = 10
|
||||
|
||||
runner_card = runner.card.batterscouting.battingcard
|
||||
jump_safe_range = runner_card.steal_low - 9
|
||||
|
||||
if this_resp.min_safe <= jump_safe_range:
|
||||
this_resp.ai_note = f"- SEND **{runner.player.name}** to third!"
|
||||
|
||||
self.logger.info(f"Returning steal response for game {game.id}: {this_resp}")
|
||||
return this_resp
|
||||
|
||||
def check_tag_from_second(self, manager_ai: ManagerAi, game: "Game") -> TagResponse:
|
||||
"""
|
||||
Check if runner on second should tag up on a fly ball.
|
||||
|
||||
Migrated from ManagerAi.tag_from_second() method.
|
||||
|
||||
Args:
|
||||
manager_ai: ManagerAi configuration
|
||||
game: Current game
|
||||
|
||||
Returns:
|
||||
TagResponse with tag decision details
|
||||
|
||||
Raises:
|
||||
GameException: If no current play found
|
||||
"""
|
||||
self._log_operation("check_tag_from_second", f"game {game.id}")
|
||||
|
||||
this_resp = TagResponse()
|
||||
this_play = game.current_play_or_none(self.session)
|
||||
if this_play is None:
|
||||
raise ValueError("No game found while checking tag_from_second")
|
||||
|
||||
ai_rd = this_play.ai_run_diff
|
||||
aggression_mod = abs(manager_ai.ahead_aggression - 5 if ai_rd > 0 else manager_ai.behind_aggression - 5)
|
||||
adjusted_running = manager_ai.running + aggression_mod
|
||||
|
||||
if adjusted_running >= 8:
|
||||
this_resp.min_safe = 4
|
||||
elif adjusted_running >= 5:
|
||||
this_resp.min_safe = 7
|
||||
else:
|
||||
this_resp.min_safe = 10
|
||||
|
||||
if this_play.starting_outs == 1:
|
||||
this_resp.min_safe -= 2
|
||||
else:
|
||||
this_resp.min_safe += 2
|
||||
|
||||
self.logger.info(f"tag_from_second response: {this_resp}")
|
||||
return this_resp
|
||||
|
||||
def check_tag_from_third(self, manager_ai: ManagerAi, game: "Game") -> TagResponse:
|
||||
"""
|
||||
Check if runner on third should tag up on a fly ball.
|
||||
|
||||
Migrated from ManagerAi.tag_from_third() method.
|
||||
|
||||
Args:
|
||||
manager_ai: ManagerAi configuration
|
||||
game: Current game
|
||||
|
||||
Returns:
|
||||
TagResponse with tag decision details
|
||||
|
||||
Raises:
|
||||
GameException: If no current play found
|
||||
"""
|
||||
self._log_operation("check_tag_from_third", f"game {game.id}")
|
||||
|
||||
this_resp = TagResponse()
|
||||
this_play = game.current_play_or_none(self.session)
|
||||
if this_play is None:
|
||||
raise ValueError("No game found while checking tag_from_third")
|
||||
|
||||
ai_rd = this_play.ai_run_diff
|
||||
aggression_mod = abs(manager_ai.ahead_aggression - 5 if ai_rd > 0 else manager_ai.behind_aggression - 5)
|
||||
adjusted_running = manager_ai.running + aggression_mod
|
||||
|
||||
if adjusted_running >= 8:
|
||||
this_resp.min_safe = 7
|
||||
elif adjusted_running >= 5:
|
||||
this_resp.min_safe = 10
|
||||
else:
|
||||
this_resp.min_safe = 12
|
||||
|
||||
if ai_rd in [-1, 0]:
|
||||
this_resp.min_safe -= 2
|
||||
|
||||
if this_play.starting_outs == 1:
|
||||
this_resp.min_safe -= 2
|
||||
|
||||
self.logger.info(f"tag_from_third response: {this_resp}")
|
||||
return this_resp
|
||||
|
||||
def decide_throw_target(self, manager_ai: ManagerAi, game: "Game") -> ThrowResponse:
|
||||
"""
|
||||
Decide where to throw on uncapped advances.
|
||||
|
||||
Migrated from ManagerAi.throw_at_uncapped() method.
|
||||
|
||||
Args:
|
||||
manager_ai: ManagerAi configuration
|
||||
game: Current game
|
||||
|
||||
Returns:
|
||||
ThrowResponse with throw target decision
|
||||
|
||||
Raises:
|
||||
GameException: If no current play found
|
||||
"""
|
||||
self._log_operation("decide_throw_target", f"game {game.id}")
|
||||
|
||||
this_resp = ThrowResponse()
|
||||
this_play = game.current_play_or_none(self.session)
|
||||
if this_play is None:
|
||||
raise ValueError("No game found while checking throw_at_uncapped")
|
||||
|
||||
ai_rd = this_play.ai_run_diff
|
||||
aggression = manager_ai.ahead_aggression if ai_rd > 0 else manager_ai.behind_aggression
|
||||
current_outs = this_play.starting_outs + this_play.outs
|
||||
|
||||
if ai_rd > 5:
|
||||
if manager_ai.ahead_aggression > 5:
|
||||
this_resp.at_trail_runner = True
|
||||
this_resp.trail_max_safe_delta = -4 + current_outs
|
||||
else:
|
||||
this_resp.cutoff = True
|
||||
elif ai_rd > 2:
|
||||
if manager_ai.ahead_aggression > 8:
|
||||
this_resp.at_trail_runner = True
|
||||
this_resp.trail_max_safe_delta = -4 + current_outs
|
||||
elif ai_rd > 0:
|
||||
if manager_ai.ahead_aggression > 8:
|
||||
this_resp.at_trail_runner = True
|
||||
this_resp.trail_max_safe_delta = -6 + current_outs
|
||||
elif ai_rd > -3:
|
||||
if manager_ai.behind_aggression < 5:
|
||||
this_resp.at_trail_runner = True
|
||||
this_resp.trail_max_safe_delta = -6 + current_outs
|
||||
elif ai_rd > -6:
|
||||
if manager_ai.behind_aggression < 5:
|
||||
this_resp.at_trail_runner = True
|
||||
this_resp.trail_max_safe_delta = -4 + current_outs
|
||||
else:
|
||||
if manager_ai.behind_aggression < 5:
|
||||
this_resp.at_trail_runner = True
|
||||
this_resp.trail_max_safe_delta = -4
|
||||
|
||||
self.logger.info(f"throw_at_uncapped response: {this_resp}")
|
||||
return this_resp
|
||||
|
||||
def decide_runner_advance(
|
||||
self,
|
||||
manager_ai: ManagerAi,
|
||||
game: "Game",
|
||||
lead_base: int,
|
||||
trail_base: int
|
||||
) -> UncappedRunResponse:
|
||||
"""
|
||||
Decide if runners should advance on uncapped situations.
|
||||
|
||||
Migrated from ManagerAi.uncapped_advance() method.
|
||||
|
||||
Args:
|
||||
manager_ai: ManagerAi configuration
|
||||
game: Current game
|
||||
lead_base: Base number for lead runner
|
||||
trail_base: Base number for trail runner
|
||||
|
||||
Returns:
|
||||
UncappedRunResponse with advance decisions
|
||||
|
||||
Raises:
|
||||
GameException: If no current play found
|
||||
"""
|
||||
self._log_operation("decide_runner_advance", f"game {game.id}, lead_base {lead_base}, trail_base {trail_base}")
|
||||
|
||||
this_resp = UncappedRunResponse()
|
||||
this_play = game.current_play_or_none(self.session)
|
||||
if this_play is None:
|
||||
raise ValueError("No game found while checking uncapped_advance")
|
||||
|
||||
ai_rd = this_play.ai_run_diff
|
||||
aggression = manager_ai.ahead_aggression - 5 if ai_rd > 0 else manager_ai.behind_aggression - 5
|
||||
|
||||
if ai_rd > 4:
|
||||
if lead_base == 4:
|
||||
this_resp.min_safe = 16 - this_play.starting_outs - aggression
|
||||
this_resp.send_trail = True
|
||||
this_resp.trail_min_safe = 10 - aggression - this_play.starting_outs - this_play.outs
|
||||
elif lead_base == 3:
|
||||
this_resp.min_safe = 14 + (this_play.starting_outs * 2) - aggression
|
||||
if this_play.starting_outs + this_play.outs >= 2:
|
||||
this_resp.send_trail = False
|
||||
elif ai_rd > 1 or ai_rd < -2:
|
||||
if lead_base == 4:
|
||||
this_resp.min_safe = 12 - this_play.starting_outs - aggression
|
||||
this_resp.send_trail = True
|
||||
this_resp.trail_min_safe = 10 - aggression - this_play.starting_outs - this_play.outs
|
||||
elif lead_base == 3:
|
||||
this_resp.min_safe = 12 + (this_play.starting_outs * 2) - (aggression * 2)
|
||||
if this_play.starting_outs + this_play.outs >= 2:
|
||||
this_resp.send_trail = False
|
||||
else:
|
||||
if lead_base == 4:
|
||||
this_resp.min_safe = 10 - this_play.starting_outs - aggression
|
||||
this_resp.send_trail = True
|
||||
this_resp.trail_min_safe = 2
|
||||
elif lead_base == 3:
|
||||
this_resp.min_safe = 14 + (this_play.starting_outs * 2) - aggression
|
||||
if this_play.starting_outs + this_play.outs >= 2:
|
||||
this_resp.send_trail = False
|
||||
|
||||
# Bounds checking
|
||||
if this_resp.min_safe > 20:
|
||||
this_resp.min_safe = 20
|
||||
if this_resp.min_safe < 1:
|
||||
this_resp.min_safe = 1
|
||||
if this_resp.trail_min_safe > 20:
|
||||
this_resp.trail_min_safe = 20
|
||||
if this_resp.trail_min_safe < 1:
|
||||
this_resp.trail_min_safe = 1
|
||||
|
||||
self.logger.info(f"uncapped advance response: {this_resp}")
|
||||
return this_resp
|
||||
|
||||
def set_defensive_alignment(self, manager_ai: ManagerAi, game: "Game") -> DefenseResponse:
|
||||
"""
|
||||
Determine defensive alignment and holds.
|
||||
|
||||
Migrated from ManagerAi.defense_alignment() method.
|
||||
|
||||
Args:
|
||||
manager_ai: ManagerAi configuration
|
||||
game: Current game
|
||||
|
||||
Returns:
|
||||
DefenseResponse with defensive decisions
|
||||
|
||||
Raises:
|
||||
GameException: If no current play found
|
||||
"""
|
||||
self._log_operation("set_defensive_alignment", f"game {game.id}")
|
||||
|
||||
this_resp = DefenseResponse()
|
||||
this_play = game.current_play_or_none(self.session)
|
||||
if this_play is None:
|
||||
raise ValueError("No game found while checking defense_alignment")
|
||||
|
||||
self.logger.info(f"defense_alignment - this_play: {this_play}")
|
||||
ai_rd = this_play.ai_run_diff
|
||||
aggression = manager_ai.ahead_aggression - 5 if ai_rd > 0 else manager_ai.behind_aggression - 5
|
||||
pitcher_hold = this_play.pitcher.card.pitcherscouting.pitchingcard.hold
|
||||
|
||||
catcher_defense = self.session.exec(
|
||||
select(PositionRating).where(
|
||||
PositionRating.player_id == this_play.catcher.player_id,
|
||||
PositionRating.position == 'C',
|
||||
PositionRating.variant == this_play.catcher.card.variant
|
||||
)
|
||||
).one()
|
||||
catcher_hold = catcher_defense.arm
|
||||
battery_hold = pitcher_hold + catcher_hold
|
||||
|
||||
# Hold decisions
|
||||
if this_play.starting_outs == 2 and this_play.on_base_code > 0:
|
||||
self.logger.info("Checking for holds with 2 outs")
|
||||
if this_play.on_base_code == 1:
|
||||
this_resp.hold_first = True
|
||||
this_resp.ai_note += f"- hold {this_play.on_first.player.name} on 1st\n"
|
||||
elif this_play.on_base_code == 2:
|
||||
this_resp.hold_second = True
|
||||
this_resp.ai_note += f"- hold {this_play.on_second.player.name} on 2nd\n"
|
||||
elif this_play.on_base_code in [4, 7]:
|
||||
this_resp.hold_first = True
|
||||
this_resp.hold_second = True
|
||||
this_resp.ai_note += f"- hold {this_play.on_first.player.name} on 1st\n- hold {this_play.on_second.player.name} on 2nd\n"
|
||||
elif this_play.on_base_code == 5:
|
||||
this_resp.hold_first = True
|
||||
this_resp.ai_note += f"- hold {this_play.on_first.player.name} on first\n"
|
||||
elif this_play.on_base_code == 6:
|
||||
this_resp.hold_second = True
|
||||
this_resp.ai_note += f"- hold {this_play.on_second.player.name} on 2nd\n"
|
||||
elif this_play.on_base_code in [1, 5]:
|
||||
self.logger.info("Checking for hold with runner on first")
|
||||
runner = this_play.on_first.player
|
||||
if (this_play.on_first.card.batterscouting.battingcard.steal_auto and
|
||||
((this_play.on_first.card.batterscouting.battingcard.steal_high + battery_hold) >= (12 - aggression))):
|
||||
this_resp.hold_first = True
|
||||
this_resp.ai_note += f"- hold {runner.name} on 1st\n"
|
||||
elif this_play.on_base_code in [2, 4]:
|
||||
self.logger.info("Checking for hold with runner on second")
|
||||
if (this_play.on_second.card.batterscouting.battingcard.steal_low + max(battery_hold, 5)) >= (14 - aggression):
|
||||
this_resp.hold_second = True
|
||||
this_resp.ai_note += f"- hold {this_play.on_second.player.name} on 2nd\n"
|
||||
|
||||
# Defensive Alignment
|
||||
if this_play.on_third and this_play.starting_outs < 2:
|
||||
if this_play.could_walkoff:
|
||||
this_resp.outfield_in = True
|
||||
this_resp.infield_in = True
|
||||
this_resp.ai_note += "- play the outfield and infield in"
|
||||
elif this_play.on_first and this_play.starting_outs == 1:
|
||||
this_resp.corners_in = True
|
||||
this_resp.ai_note += "- play the corners in\n"
|
||||
elif abs(this_play.away_score - this_play.home_score) <= 3:
|
||||
this_resp.infield_in = True
|
||||
this_resp.ai_note += "- play the whole infield in\n"
|
||||
else:
|
||||
this_resp.corners_in = True
|
||||
this_resp.ai_note += "- play the corners in\n"
|
||||
|
||||
if len(this_resp.ai_note) == 0 and this_play.on_base_code > 0:
|
||||
this_resp.ai_note += "- play straight up\n"
|
||||
|
||||
self.logger.info(f"Defense alignment response: {this_resp}")
|
||||
return this_resp
|
||||
|
||||
def decide_groundball_running(self, manager_ai: ManagerAi, game: "Game") -> RunResponse:
|
||||
"""
|
||||
Decide if AI should run on groundball.
|
||||
|
||||
Migrated from ManagerAi.gb_decide_run() method.
|
||||
|
||||
Args:
|
||||
manager_ai: ManagerAi configuration
|
||||
game: Current game
|
||||
|
||||
Returns:
|
||||
RunResponse with running decision
|
||||
|
||||
Raises:
|
||||
GameException: If no current play found
|
||||
"""
|
||||
self._log_operation("decide_groundball_running", f"game {game.id}")
|
||||
|
||||
this_resp = RunResponse()
|
||||
this_play = game.current_play_or_none(self.session)
|
||||
if this_play is None:
|
||||
raise ValueError("No game found while checking gb_decide_run")
|
||||
|
||||
ai_rd = this_play.ai_run_diff
|
||||
aggression = manager_ai.ahead_aggression - 5 if ai_rd > 0 else manager_ai.behind_aggression - 5
|
||||
|
||||
this_resp.min_safe = 15 - aggression # TODO: write this algorithm
|
||||
self.logger.info(f"gb_decide_run response: {this_resp}")
|
||||
return this_resp
|
||||
|
||||
def decide_groundball_throw(
|
||||
self,
|
||||
manager_ai: ManagerAi,
|
||||
game: "Game",
|
||||
runner_speed: int,
|
||||
defender_range: int
|
||||
) -> ThrowResponse:
|
||||
"""
|
||||
Decide where to throw on groundball with runner.
|
||||
|
||||
Migrated from ManagerAi.gb_decide_throw() method.
|
||||
|
||||
Args:
|
||||
manager_ai: ManagerAi configuration
|
||||
game: Current game
|
||||
runner_speed: Speed of the runner
|
||||
defender_range: Range of the fielding defender
|
||||
|
||||
Returns:
|
||||
ThrowResponse with throw decision
|
||||
|
||||
Raises:
|
||||
GameException: If no current play found
|
||||
"""
|
||||
self._log_operation("decide_groundball_throw", f"game {game.id}")
|
||||
|
||||
this_resp = ThrowResponse(at_lead_runner=True)
|
||||
this_play = game.current_play_or_none(self.session)
|
||||
if this_play is None:
|
||||
raise ValueError("No game found while checking gb_decide_throw")
|
||||
|
||||
ai_rd = this_play.ai_run_diff
|
||||
aggression = manager_ai.ahead_aggression - 5 if ai_rd > 0 else manager_ai.behind_aggression - 5
|
||||
|
||||
if (runner_speed - 4 + defender_range) <= (10 + aggression):
|
||||
this_resp.at_lead_runner = True
|
||||
|
||||
self.logger.info(f"gb_decide_throw response: {this_resp}")
|
||||
return this_resp
|
||||
|
||||
def should_replace_pitcher(self, manager_ai: ManagerAi, game: "Game") -> bool:
|
||||
"""
|
||||
Determine if fatigued pitcher should be replaced.
|
||||
|
||||
Migrated from ManagerAi.replace_pitcher() method.
|
||||
|
||||
Args:
|
||||
manager_ai: ManagerAi configuration
|
||||
game: Current game
|
||||
|
||||
Returns:
|
||||
bool: True if pitcher should be replaced
|
||||
|
||||
Raises:
|
||||
GameException: If no current play found
|
||||
"""
|
||||
self._log_operation("should_replace_pitcher", f"game {game.id}")
|
||||
|
||||
this_play = game.current_play_or_none(self.session)
|
||||
if this_play is None:
|
||||
raise ValueError("No game found while checking replace_pitcher")
|
||||
|
||||
this_pitcher = this_play.pitcher
|
||||
outs = self.session.exec(
|
||||
select(func.sum("Play.outs")).where(
|
||||
"Play.game" == game,
|
||||
"Play.pitcher" == this_pitcher,
|
||||
"Play.complete" == True
|
||||
)
|
||||
).one()
|
||||
self.logger.info(f"Pitcher: {this_pitcher.card.player.name_with_desc} / Outs: {outs}")
|
||||
|
||||
allowed_runners = self.session.exec(
|
||||
select(func.count("Play.id")).where(
|
||||
"Play.game" == game,
|
||||
"Play.pitcher" == this_pitcher,
|
||||
or_("Play.hit" == 1, "Play.bb" == 1)
|
||||
)
|
||||
).one()
|
||||
run_diff = this_play.ai_run_diff
|
||||
|
||||
self.logger.info(f"run diff: {run_diff} / allowed runners: {allowed_runners} / behind aggro: {manager_ai.behind_aggression} / ahead aggro: {manager_ai.ahead_aggression}")
|
||||
self.logger.info(f"this play: {this_play}")
|
||||
|
||||
if this_pitcher.replacing_id is None:
|
||||
# Starter logic
|
||||
pitcher_pow = this_pitcher.card.pitcherscouting.pitchingcard.starter_rating
|
||||
self.logger.info(f"Starter POW: {pitcher_pow}")
|
||||
|
||||
if outs >= pitcher_pow * 3 + 6:
|
||||
self.logger.info("Starter has thrown POW + 3 - being pulled")
|
||||
return True
|
||||
|
||||
elif allowed_runners < 5:
|
||||
self.logger.info(f"Starter is cooking with {allowed_runners} runners allowed - staying in")
|
||||
return False
|
||||
|
||||
elif this_pitcher.is_fatigued and this_play.on_base_code > 1:
|
||||
self.logger.info("Starter is fatigued")
|
||||
return True
|
||||
|
||||
elif (run_diff > 5 or (run_diff > 2 and manager_ai.ahead_aggression > 5)) and (allowed_runners < run_diff or this_play.on_base_code <= 3):
|
||||
self.logger.info(f"AI team has big lead of {run_diff} - staying in")
|
||||
return False
|
||||
|
||||
elif (run_diff > 2 or (run_diff >= 0 and manager_ai.ahead_aggression > 5)) and (allowed_runners < run_diff or this_play.on_base_code <= 1):
|
||||
self.logger.info(f"AI team has lead of {run_diff} - staying in")
|
||||
return False
|
||||
|
||||
elif (run_diff >= 0 or (run_diff >= -2 and manager_ai.behind_aggression > 5)) and (allowed_runners < 5 and this_play.on_base_code <= run_diff):
|
||||
self.logger.info(f"AI team in close game with run diff of {run_diff} - staying in")
|
||||
return False
|
||||
|
||||
elif run_diff >= -3 and manager_ai.behind_aggression > 5 and allowed_runners < 5 and this_play.on_base_code <= 1:
|
||||
self.logger.info(f"AI team is close behind with run diff of {run_diff} - staying in")
|
||||
return False
|
||||
|
||||
elif run_diff <= -5 and this_play.inning_num <= 3:
|
||||
self.logger.info("AI team is way behind and starter is going to wear it - staying in")
|
||||
return False
|
||||
|
||||
else:
|
||||
self.logger.info("AI team found no exceptions - pull starter")
|
||||
return True
|
||||
|
||||
else:
|
||||
# Reliever logic
|
||||
pitcher_pow = this_pitcher.card.pitcherscouting.pitchingcard.relief_rating
|
||||
self.logger.info(f"Reliever POW: {pitcher_pow}")
|
||||
|
||||
if outs >= pitcher_pow * 3 + 3:
|
||||
self.logger.info("Only allow POW + 1 IP - pull reliever")
|
||||
return True
|
||||
|
||||
elif this_pitcher.is_fatigued and this_play.is_new_inning:
|
||||
self.logger.info("Reliever is fatigued to start the inning - pull reliever")
|
||||
return True
|
||||
|
||||
elif (run_diff > 5 or (run_diff > 2 and manager_ai.ahead_aggression > 5)) and (this_play.starting_outs == 2 or allowed_runners <= run_diff or this_play.on_base_code <= 3 or this_play.starting_outs == 2):
|
||||
self.logger.info(f"AI team has big lead of {run_diff} - staying in")
|
||||
return False
|
||||
|
||||
elif (run_diff > 2 or (run_diff >= 0 and manager_ai.ahead_aggression > 5)) and (allowed_runners < run_diff or this_play.on_base_code <= 1 or this_play.starting_outs == 2):
|
||||
self.logger.info(f"AI team has lead of {run_diff} - staying in")
|
||||
return False
|
||||
|
||||
elif (run_diff >= 0 or (run_diff >= -2 and manager_ai.behind_aggression > 5)) and (allowed_runners < 5 or this_play.on_base_code <= run_diff or this_play.starting_outs == 2):
|
||||
self.logger.info(f"AI team in close game with run diff of {run_diff} - staying in")
|
||||
return False
|
||||
|
||||
elif run_diff >= -3 and manager_ai.behind_aggression > 5 and allowed_runners < 5 and this_play.on_base_code <= 1:
|
||||
self.logger.info(f"AI team is close behind with run diff of {run_diff} - staying in")
|
||||
return False
|
||||
|
||||
elif run_diff <= -5 and this_play.starting_outs != 0:
|
||||
self.logger.info("AI team is way behind and reliever is going to wear it - staying in")
|
||||
return False
|
||||
|
||||
else:
|
||||
self.logger.info("AI team found no exceptions - pull reliever")
|
||||
return True
|
||||
@ -82,9 +82,16 @@ def get_ai_service(session: SessionDep):
|
||||
return AIService(session)
|
||||
|
||||
|
||||
def get_ui_service(session: SessionDep):
|
||||
"""Get UIService instance."""
|
||||
from app.services.ui_service import UIService
|
||||
return UIService(session)
|
||||
|
||||
|
||||
# Type aliases for service dependencies
|
||||
GameServiceDep = Annotated[object, Depends(get_game_service)]
|
||||
UserServiceDep = Annotated[object, Depends(get_user_service)]
|
||||
AuthServiceDep = Annotated[object, Depends(get_auth_service)]
|
||||
GameplayServiceDep = Annotated[object, Depends(get_gameplay_service)]
|
||||
AIServiceDep = Annotated[object, Depends(get_ai_service)]
|
||||
UIServiceDep = Annotated[object, Depends(get_ui_service)]
|
||||
51
app/services/ui_service.py
Normal file
51
app/services/ui_service.py
Normal file
@ -0,0 +1,51 @@
|
||||
"""UI Service - handles user interface formatting and display logic.
|
||||
|
||||
This service contains all business logic for formatting data for display
|
||||
that was extracted from models during the migration from Discord app.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from .base_service import BaseService
|
||||
from ..models.team import Team
|
||||
|
||||
|
||||
class UIService(BaseService):
|
||||
"""Service for user interface formatting and display logic."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.logger = logging.getLogger(f'{__name__}.{self.__class__.__name__}')
|
||||
|
||||
def format_team_display(self, team: Team) -> dict:
|
||||
"""Format team for web display (extracted from Discord embed property).
|
||||
|
||||
Args:
|
||||
team: Team model instance
|
||||
|
||||
Returns:
|
||||
dict: Formatted team display data for web UI
|
||||
"""
|
||||
self._log_operation(f"Formatting team display for team {team.id}")
|
||||
|
||||
try:
|
||||
# Constants from original Discord app
|
||||
SBA_COLOR = 'a6ce39'
|
||||
SBA_LOGO = 'https://paper-dynasty.s3.us-east-1.amazonaws.com/static-images/sba-logo.png'
|
||||
|
||||
display_data = {
|
||||
'title': team.lname,
|
||||
'color': team.color if team.color else SBA_COLOR,
|
||||
'footer_text': f'Paper Dynasty Season {team.season}',
|
||||
'footer_icon': SBA_LOGO,
|
||||
'thumbnail': team.logo if team.logo else SBA_LOGO,
|
||||
'team_id': team.id,
|
||||
'abbrev': team.abbrev,
|
||||
'season': team.season
|
||||
}
|
||||
|
||||
self.logger.info(f"Successfully formatted team display for {team.abbrev}")
|
||||
return display_data
|
||||
|
||||
except Exception as e:
|
||||
self._log_error(f"format_team_display for team {team.id}", e)
|
||||
raise
|
||||
@ -18,5 +18,24 @@ services:
|
||||
timeout: 10s
|
||||
retries: 5
|
||||
|
||||
postgres-test:
|
||||
image: postgres:15
|
||||
container_name: pdtest-postgres
|
||||
environment:
|
||||
POSTGRES_USER: paper_dynasty_user
|
||||
POSTGRES_PASSWORD: paper_dynasty_test_password
|
||||
POSTGRES_DB: paper_dynasty_test
|
||||
ports:
|
||||
- "5434:5432"
|
||||
volumes:
|
||||
- postgres_test_data:/var/lib/postgresql/data
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U paper_dynasty_user -d paper_dynasty_test"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 5
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
postgres_test_data:
|
||||
17
pytest.ini
Normal file
17
pytest.ini
Normal file
@ -0,0 +1,17 @@
|
||||
[tool:pytest]
|
||||
testpaths = tests
|
||||
python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
addopts =
|
||||
--tb=short
|
||||
--strict-markers
|
||||
--disable-warnings
|
||||
--verbose
|
||||
markers =
|
||||
unit: Unit tests that mock dependencies
|
||||
integration: Integration tests with database
|
||||
slow: Slow running tests
|
||||
filterwarnings =
|
||||
ignore::sqlalchemy.exc.SAWarning
|
||||
ignore::DeprecationWarning
|
||||
170
tests/README.md
170
tests/README.md
@ -173,10 +173,103 @@ def test_game_data():
|
||||
```
|
||||
|
||||
### Test Database
|
||||
Integration tests use a separate test database:
|
||||
- `DATABASE_TEST_URL` environment variable
|
||||
- Isolated transactions (rollback after each test)
|
||||
- Clean state for each test
|
||||
|
||||
Integration tests use a separate PostgreSQL test database:
|
||||
- **Container**: `pdtest-postgres` on port 5434 (via docker-compose)
|
||||
- **URL**: `postgresql://paper_dynasty_user:paper_dynasty_test_password@localhost:5434/paper_dynasty_test`
|
||||
- **Isolation**: Transaction rollback after each test
|
||||
- **Clean state**: Each test runs in isolation
|
||||
|
||||
#### Database Testing Strategy
|
||||
|
||||
**🚨 CRITICAL: Always Use Centralized Fixtures**
|
||||
|
||||
**✅ CORRECT - Use centralized `db_session` fixture from `conftest.py`:**
|
||||
```python
|
||||
# ✅ Good - uses proper rollback
|
||||
def test_create_team(db_session):
|
||||
team = TeamFactory.create(db_session, name="Test Team")
|
||||
assert team.id is not None
|
||||
```
|
||||
|
||||
**❌ WRONG - Never create custom database fixtures:**
|
||||
```python
|
||||
# ❌ BAD - creates data persistence issues
|
||||
@pytest.fixture
|
||||
def session(test_db):
|
||||
with Session(test_db) as session:
|
||||
yield session # No rollback!
|
||||
```
|
||||
|
||||
**Transaction Rollback Pattern** (Already implemented in `conftest.py`):
|
||||
```python
|
||||
@pytest.fixture
|
||||
def db_session(test_engine):
|
||||
"""Database session with transaction rollback for test isolation."""
|
||||
connection = test_engine.connect()
|
||||
transaction = connection.begin()
|
||||
session = Session(bind=connection)
|
||||
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
transaction.rollback() # ✅ Automatic cleanup
|
||||
connection.close()
|
||||
```
|
||||
|
||||
**Benefits**:
|
||||
- ✅ Complete test isolation
|
||||
- ✅ Fast execution (no actual database writes)
|
||||
- ✅ No cleanup required
|
||||
- ✅ Deterministic test results
|
||||
|
||||
**🚨 CRITICAL: Always Use Test Factories**
|
||||
|
||||
**✅ CORRECT - Use factories with unique IDs:**
|
||||
```python
|
||||
# ✅ Good - uses factory with unique ID generation
|
||||
def test_create_cardset(db_session):
|
||||
cardset = CardsetFactory.create(db_session, name="Test Set")
|
||||
assert cardset.id is not None
|
||||
```
|
||||
|
||||
**❌ WRONG - Never use hardcoded IDs:**
|
||||
```python
|
||||
# ❌ BAD - hardcoded IDs cause conflicts
|
||||
def test_create_cardset(db_session):
|
||||
cardset = Cardset(id=1, name="Test Set") # Will conflict!
|
||||
db_session.add(cardset)
|
||||
db_session.commit()
|
||||
```
|
||||
|
||||
**Factory Pattern** (see `tests/factories/`):
|
||||
```python
|
||||
class CardsetFactory:
|
||||
@staticmethod
|
||||
def build(**kwargs):
|
||||
defaults = {
|
||||
'id': generate_unique_id(), # ✅ Unique every time
|
||||
'name': generate_unique_name('Cardset'),
|
||||
'ranked_legal': False
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return Cardset(**defaults)
|
||||
|
||||
@staticmethod
|
||||
def create(session, **kwargs):
|
||||
cardset = CardsetFactory.build(**kwargs)
|
||||
session.add(cardset)
|
||||
session.commit()
|
||||
session.refresh(cardset)
|
||||
return cardset
|
||||
```
|
||||
|
||||
**Benefits**:
|
||||
- ✅ Unique data per test
|
||||
- ✅ No ID conflicts
|
||||
- ✅ Customizable test data
|
||||
- ✅ Readable test code
|
||||
|
||||
## Testing Best Practices
|
||||
|
||||
@ -191,6 +284,75 @@ Integration tests use a separate test database:
|
||||
- **Isolate test state** (no shared mutable state)
|
||||
- **Clean up after tests** (database rollback)
|
||||
|
||||
### 🚨 Test Isolation Requirements
|
||||
|
||||
**MANDATORY for all new tests:**
|
||||
|
||||
1. **Use `db_session` fixture** from `conftest.py` - never create custom session fixtures
|
||||
2. **Use factory classes** for all test data - never hardcode IDs or use static values
|
||||
3. **Import factories** from `tests.factories` package
|
||||
4. **Test in isolation** - each test should work independently
|
||||
|
||||
**Checklist for New Tests:**
|
||||
```python
|
||||
# ✅ Required imports
|
||||
from tests.factories.team_factory import TeamFactory
|
||||
|
||||
# ✅ Required fixture usage
|
||||
def test_something(db_session): # Use db_session, not session
|
||||
pass
|
||||
|
||||
# ✅ Required factory usage
|
||||
team = TeamFactory.create(db_session, name="Custom Name")
|
||||
# NOT: team = Team(id=1, name="Custom Name")
|
||||
|
||||
# ✅ Required test isolation
|
||||
# Each test should be runnable independently and repeatedly
|
||||
```
|
||||
|
||||
### 🚨 Common Anti-Patterns to Avoid
|
||||
|
||||
**❌ Creating Custom Database Fixtures:**
|
||||
```python
|
||||
# DON'T DO THIS - breaks test isolation
|
||||
@pytest.fixture
|
||||
def session(test_db):
|
||||
with Session(test_db) as session:
|
||||
yield session
|
||||
```
|
||||
|
||||
**❌ Using Hardcoded IDs:**
|
||||
```python
|
||||
# DON'T DO THIS - causes primary key conflicts
|
||||
cardset = Cardset(id=1, name="Test")
|
||||
team = Team(id=100, abbrev="TST")
|
||||
```
|
||||
|
||||
**❌ Manual Model Creation:**
|
||||
```python
|
||||
# DON'T DO THIS - creates duplicate and brittle tests
|
||||
def test_something(db_session):
|
||||
cardset = Cardset(
|
||||
id=generate_unique_id(),
|
||||
name="Manual Cardset",
|
||||
ranked_legal=False
|
||||
)
|
||||
db_session.add(cardset)
|
||||
db_session.commit()
|
||||
```
|
||||
|
||||
**✅ Correct Patterns:**
|
||||
```python
|
||||
# DO THIS - uses proper isolation and factories
|
||||
def test_something(db_session):
|
||||
cardset = CardsetFactory.create(
|
||||
db_session,
|
||||
name="Test Cardset",
|
||||
ranked_legal=False
|
||||
)
|
||||
# Test logic here
|
||||
```
|
||||
|
||||
### Coverage Goals
|
||||
- **Services**: 90%+ coverage (core business logic)
|
||||
- **Engine**: 95%+ coverage (critical game mechanics)
|
||||
|
||||
396
tests/TEST_ISOLATION_GUIDE.md
Normal file
396
tests/TEST_ISOLATION_GUIDE.md
Normal file
@ -0,0 +1,396 @@
|
||||
# 🚨 Test Isolation Best Practices Guide
|
||||
|
||||
**CRITICAL: This guide prevents data persistence issues and test conflicts.**
|
||||
|
||||
## The Problem We Solved
|
||||
|
||||
Previously, tests were creating their own database fixtures and using hardcoded IDs, causing:
|
||||
- ❌ Data persistence between test runs
|
||||
- ❌ Primary key conflicts
|
||||
- ❌ Tests depending on execution order
|
||||
- ❌ Intermittent test failures
|
||||
- ❌ Polluted test database
|
||||
|
||||
## The Solution: Centralized Fixtures + Factory Pattern
|
||||
|
||||
### ✅ ALWAYS DO: Use Centralized Database Fixtures
|
||||
|
||||
**Use the `db_session` fixture from `conftest.py`:**
|
||||
|
||||
```python
|
||||
# ✅ CORRECT
|
||||
def test_create_team(db_session):
|
||||
team = TeamFactory.create(db_session, abbrev="LAD")
|
||||
assert team.id is not None
|
||||
```
|
||||
|
||||
**This fixture provides:**
|
||||
- ✅ Automatic transaction rollback after each test
|
||||
- ✅ Complete test isolation
|
||||
- ✅ Fast execution (no actual database writes)
|
||||
- ✅ Deterministic results
|
||||
|
||||
### ❌ NEVER DO: Create Custom Database Fixtures
|
||||
|
||||
```python
|
||||
# ❌ WRONG - Creates data persistence issues
|
||||
@pytest.fixture
|
||||
def session(test_db):
|
||||
with Session(test_db) as session:
|
||||
yield session # No rollback!
|
||||
|
||||
@pytest.fixture
|
||||
def my_custom_session():
|
||||
# Custom session logic
|
||||
pass
|
||||
```
|
||||
|
||||
**Why this is wrong:**
|
||||
- Data persists between tests
|
||||
- No automatic cleanup
|
||||
- Tests interfere with each other
|
||||
- Inconsistent test results
|
||||
|
||||
### ✅ ALWAYS DO: Use Test Factories
|
||||
|
||||
**Use factory classes for all test data:**
|
||||
|
||||
```python
|
||||
# ✅ CORRECT
|
||||
from tests.factories.team_factory import TeamFactory
|
||||
|
||||
def test_team_creation(db_session):
|
||||
team = TeamFactory.create(db_session, abbrev="BOS")
|
||||
assert team.abbrev == "BOS"
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- ✅ Unique IDs every time
|
||||
- ✅ No conflicts between tests
|
||||
- ✅ Consistent, valid data
|
||||
- ✅ Customizable per test
|
||||
|
||||
### ❌ NEVER DO: Manual Model Creation with Hardcoded IDs
|
||||
|
||||
```python
|
||||
# ❌ WRONG - Hardcoded IDs cause conflicts
|
||||
def test_bad_team_creation(db_session):
|
||||
team = Team(
|
||||
id=1, # ❌ Will conflict with other tests
|
||||
abbrev="TST",
|
||||
lname="Test Team",
|
||||
# ... many required fields
|
||||
)
|
||||
db_session.add(team)
|
||||
db_session.commit()
|
||||
```
|
||||
|
||||
**Why this is wrong:**
|
||||
- Primary key conflicts between tests
|
||||
- Brittle when test data requirements change
|
||||
- Verbose and hard to maintain
|
||||
- No guarantee of unique data
|
||||
|
||||
## Detailed Implementation Guide
|
||||
|
||||
### 1. Database Session Usage
|
||||
|
||||
**✅ CORRECT Pattern:**
|
||||
```python
|
||||
def test_something(db_session): # Parameter name must be 'db_session'
|
||||
# Create test data using factories
|
||||
team = TeamFactory.create(db_session, abbrev="TEST")
|
||||
|
||||
# Perform test operations
|
||||
result = some_service_operation(team)
|
||||
|
||||
# Make assertions
|
||||
assert result is not None
|
||||
|
||||
# No cleanup needed - automatic rollback
|
||||
```
|
||||
|
||||
**❌ WRONG Patterns:**
|
||||
```python
|
||||
# Don't define custom fixtures
|
||||
@pytest.fixture
|
||||
def session():
|
||||
pass
|
||||
|
||||
# Don't use different parameter names
|
||||
def test_something(custom_session):
|
||||
pass
|
||||
|
||||
# Don't create sessions manually
|
||||
def test_something():
|
||||
with Session(engine) as session:
|
||||
pass
|
||||
```
|
||||
|
||||
### 2. Factory Usage Patterns
|
||||
|
||||
**✅ CORRECT Factory Usage:**
|
||||
```python
|
||||
# Basic creation
|
||||
team = TeamFactory.create(db_session, abbrev="LAD")
|
||||
|
||||
# Custom values
|
||||
ai_team = TeamFactory.create(db_session, is_ai=True, wallet=100000)
|
||||
|
||||
# Specialized methods
|
||||
ai_team = TeamFactory.build_ai_team()
|
||||
human_team = TeamFactory.build_human_team()
|
||||
|
||||
# Multiple objects
|
||||
teams = TeamFactory.build_multiple(3, season=9)
|
||||
```
|
||||
|
||||
**❌ WRONG Manual Creation:**
|
||||
```python
|
||||
# Don't create models manually
|
||||
team = Team(id=1, abbrev="TST", ...)
|
||||
|
||||
# Don't use non-unique values
|
||||
team1 = Team(id=100, abbrev="SAME")
|
||||
team2 = Team(id=100, abbrev="SAME") # Conflict!
|
||||
|
||||
# Don't skip required fields
|
||||
team = Team(abbrev="TST") # Missing required fields
|
||||
```
|
||||
|
||||
### 3. Test Structure Template
|
||||
|
||||
**Use this template for all new database tests:**
|
||||
|
||||
```python
|
||||
"""
|
||||
Test module for [functionality].
|
||||
|
||||
Tests [describe what is being tested].
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from tests.factories.team_factory import TeamFactory
|
||||
from tests.factories.cardset_factory import CardsetFactory
|
||||
# Import other factories as needed
|
||||
|
||||
class TestSomeFunctionality:
|
||||
"""Test [specific functionality]."""
|
||||
|
||||
def test_basic_case(self, db_session):
|
||||
"""Test basic functionality works."""
|
||||
# Arrange - create test data
|
||||
team = TeamFactory.create(db_session, abbrev="TEST")
|
||||
|
||||
# Act - perform operation
|
||||
result = perform_operation(team)
|
||||
|
||||
# Assert - verify results
|
||||
assert result.success is True
|
||||
|
||||
def test_edge_case(self, db_session):
|
||||
"""Test edge case handling."""
|
||||
# Arrange
|
||||
special_team = TeamFactory.create(
|
||||
db_session,
|
||||
is_ai=True,
|
||||
wallet=0 # Edge case: no money
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(InsufficientFundsError):
|
||||
perform_expensive_operation(special_team)
|
||||
|
||||
def test_multiple_objects(self, db_session):
|
||||
"""Test with multiple related objects."""
|
||||
# Arrange
|
||||
teams = TeamFactory.build_multiple(3)
|
||||
cardset = CardsetFactory.create(db_session, ranked_legal=True)
|
||||
|
||||
for team in teams:
|
||||
db_session.add(team)
|
||||
db_session.commit()
|
||||
|
||||
# Act
|
||||
result = operation_with_multiple_teams(teams, cardset)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 3
|
||||
```
|
||||
|
||||
## Common Scenarios and Solutions
|
||||
|
||||
### Scenario 1: Testing Team Creation
|
||||
|
||||
**✅ CORRECT:**
|
||||
```python
|
||||
def test_create_team(db_session):
|
||||
team = TeamFactory.create(
|
||||
db_session,
|
||||
abbrev="LAD",
|
||||
lname="Los Angeles Dodgers",
|
||||
wallet=50000
|
||||
)
|
||||
|
||||
assert team.id is not None
|
||||
assert team.abbrev == "LAD"
|
||||
assert team.wallet == 50000
|
||||
```
|
||||
|
||||
**❌ WRONG:**
|
||||
```python
|
||||
def test_create_team(db_session):
|
||||
team = Team(
|
||||
id=1, # Hardcoded ID
|
||||
abbrev="LAD",
|
||||
lname="Los Angeles Dodgers",
|
||||
gmid=100, # More hardcoded values
|
||||
# ... many required fields
|
||||
)
|
||||
db_session.add(team)
|
||||
db_session.commit()
|
||||
```
|
||||
|
||||
### Scenario 2: Testing with Related Objects
|
||||
|
||||
**✅ CORRECT:**
|
||||
```python
|
||||
def test_game_with_teams(db_session):
|
||||
home_team = TeamFactory.create(db_session, abbrev="HOME")
|
||||
away_team = TeamFactory.create(db_session, abbrev="AWAY")
|
||||
cardset = CardsetFactory.create(db_session, ranked_legal=True)
|
||||
|
||||
# Each object has unique ID automatically
|
||||
game = create_game(home_team, away_team, cardset)
|
||||
assert game.home_team_id == home_team.id
|
||||
```
|
||||
|
||||
**❌ WRONG:**
|
||||
```python
|
||||
def test_game_with_teams(db_session):
|
||||
home_team = Team(id=1, abbrev="HOME", ...)
|
||||
away_team = Team(id=2, abbrev="AWAY", ...)
|
||||
# Verbose and error-prone
|
||||
```
|
||||
|
||||
### Scenario 3: Testing AI Behavior
|
||||
|
||||
**✅ CORRECT:**
|
||||
```python
|
||||
def test_ai_decision_making(db_session):
|
||||
aggressive_ai = ManagerAiFactory.create_aggressive(db_session)
|
||||
conservative_ai = ManagerAiFactory.create_conservative(db_session)
|
||||
|
||||
# Test different AI personalities
|
||||
agg_decision = aggressive_ai.make_decision(situation)
|
||||
cons_decision = conservative_ai.make_decision(situation)
|
||||
|
||||
assert agg_decision.risk_level > cons_decision.risk_level
|
||||
```
|
||||
|
||||
**❌ WRONG:**
|
||||
```python
|
||||
def test_ai_decision_making(db_session):
|
||||
ai1 = ManagerAi(id=1, steal=10, running=10, ...)
|
||||
ai2 = ManagerAi(id=2, steal=2, running=2, ...)
|
||||
# Manual setup of complex objects
|
||||
```
|
||||
|
||||
## Verification Checklist
|
||||
|
||||
Before submitting any test that uses the database, verify:
|
||||
|
||||
### ✅ Fixture Usage
|
||||
- [ ] Uses `db_session` parameter from `conftest.py`
|
||||
- [ ] Does NOT define custom session fixtures
|
||||
- [ ] Does NOT use `session`, `test_db`, or other custom names
|
||||
|
||||
### ✅ Factory Usage
|
||||
- [ ] Imports factories from `tests.factories`
|
||||
- [ ] Uses `Factory.create()` or `Factory.build()` methods
|
||||
- [ ] Does NOT create models with `Model(id=hardcoded_value)`
|
||||
- [ ] Does NOT use static/hardcoded values that could conflict
|
||||
|
||||
### ✅ Test Isolation
|
||||
- [ ] Test can be run independently
|
||||
- [ ] Test can be run multiple times without failure
|
||||
- [ ] Test does not depend on execution order
|
||||
- [ ] Test does not modify shared state
|
||||
|
||||
### ✅ Data Cleanup
|
||||
- [ ] No manual cleanup code needed
|
||||
- [ ] Relies on automatic transaction rollback
|
||||
- [ ] Does not call `session.rollback()` manually
|
||||
|
||||
## Debugging Test Isolation Issues
|
||||
|
||||
### Problem: Tests pass individually but fail when run together
|
||||
|
||||
**Diagnosis:**
|
||||
```bash
|
||||
# Run individual test
|
||||
pytest tests/unit/models/test_team.py::test_create_team -v # ✅ Passes
|
||||
|
||||
# Run all tests
|
||||
pytest tests/unit/models/test_team.py -v # ❌ Fails
|
||||
```
|
||||
|
||||
**Likely Causes:**
|
||||
1. Using hardcoded IDs that conflict
|
||||
2. Not using the `db_session` fixture
|
||||
3. Sharing mutable state between tests
|
||||
4. Custom fixtures without proper cleanup
|
||||
|
||||
**Solution:**
|
||||
1. Check all model creation uses factories
|
||||
2. Verify `db_session` fixture usage
|
||||
3. Ensure unique IDs via `generate_unique_id()`
|
||||
|
||||
### Problem: "duplicate key value violates unique constraint"
|
||||
|
||||
**Error Message:**
|
||||
```
|
||||
IntegrityError: (psycopg2.errors.UniqueViolation) duplicate key value violates unique constraint "team_pkey"
|
||||
DETAIL: Key (id)=(1) already exists.
|
||||
```
|
||||
|
||||
**Cause:** Using hardcoded IDs instead of factory-generated unique IDs
|
||||
|
||||
**Solution:**
|
||||
```python
|
||||
# ❌ WRONG
|
||||
team = Team(id=1, ...)
|
||||
|
||||
# ✅ CORRECT
|
||||
team = TeamFactory.create(db_session, ...)
|
||||
```
|
||||
|
||||
### Problem: Tests find unexpected data
|
||||
|
||||
**Symptom:**
|
||||
```python
|
||||
# Expected 1 cardset, found 8
|
||||
assert len(cardsets) == 1 # Fails: found old data
|
||||
```
|
||||
|
||||
**Cause:** Previous tests didn't use transaction rollback
|
||||
|
||||
**Solution:**
|
||||
1. Clean test database: `TRUNCATE TABLE cardset CASCADE`
|
||||
2. Fix all tests to use `db_session` fixture
|
||||
3. Verify proper transaction rollback
|
||||
|
||||
## Summary: The Two Golden Rules
|
||||
|
||||
### 🥇 Rule #1: Always Use `db_session` Fixture
|
||||
```python
|
||||
def test_anything_with_database(db_session): # ✅ CORRECT
|
||||
pass
|
||||
```
|
||||
|
||||
### 🥇 Rule #2: Always Use Factory Classes
|
||||
```python
|
||||
team = TeamFactory.create(db_session, custom_field="value") # ✅ CORRECT
|
||||
```
|
||||
|
||||
Following these two rules prevents 99% of test isolation issues and ensures reliable, maintainable tests.
|
||||
78
tests/conftest.py
Normal file
78
tests/conftest.py
Normal file
@ -0,0 +1,78 @@
|
||||
"""
|
||||
Shared pytest fixtures for Paper Dynasty web app testing.
|
||||
|
||||
Provides database sessions, test data factories, and common testing utilities
|
||||
following the transaction rollback pattern for test isolation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
from uuid import uuid4
|
||||
|
||||
|
||||
# Test Database Configuration
|
||||
TEST_DATABASE_URL = "postgresql://paper_dynasty_user:paper_dynasty_test_password@localhost:5434/paper_dynasty_test"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def test_engine():
|
||||
"""Create test database engine for the entire test session."""
|
||||
engine = create_engine(TEST_DATABASE_URL, echo=False)
|
||||
|
||||
# Create all tables
|
||||
SQLModel.metadata.create_all(engine)
|
||||
|
||||
yield engine
|
||||
|
||||
# Optional: Drop all tables after test session
|
||||
# SQLModel.metadata.drop_all(engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_session(test_engine):
|
||||
"""
|
||||
Create database session with transaction rollback for test isolation.
|
||||
|
||||
This is the primary fixture for database tests. Each test runs in a
|
||||
transaction that is rolled back after the test completes, ensuring
|
||||
complete isolation between tests.
|
||||
"""
|
||||
# Create a connection and start a transaction
|
||||
connection = test_engine.connect()
|
||||
transaction = connection.begin()
|
||||
|
||||
# Create session bound to the connection
|
||||
session = Session(bind=connection)
|
||||
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
# Close session and rollback transaction
|
||||
session.close()
|
||||
transaction.rollback()
|
||||
connection.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fresh_db_session(test_engine):
|
||||
"""
|
||||
Create database session for integration tests that need real commits.
|
||||
|
||||
Use this fixture for tests that specifically need to test commit behavior
|
||||
or cross-transaction functionality. Use sparingly as these tests are slower
|
||||
and require manual cleanup.
|
||||
"""
|
||||
with Session(test_engine) as session:
|
||||
yield session
|
||||
# Manual cleanup would go here if needed
|
||||
|
||||
|
||||
def generate_unique_id():
|
||||
"""Generate unique integer ID for test data."""
|
||||
# Use last 8 digits of uuid4 as integer to avoid conflicts
|
||||
return int(str(uuid4()).replace('-', '')[-8:], 16)
|
||||
|
||||
|
||||
def generate_unique_name(prefix="Test"):
|
||||
"""Generate unique name for test data."""
|
||||
return f"{prefix} {uuid4().hex[:8]}"
|
||||
364
tests/factories/README.md
Normal file
364
tests/factories/README.md
Normal file
@ -0,0 +1,364 @@
|
||||
# Test Factories
|
||||
|
||||
This directory contains factory classes for generating unique, valid test data. Factories are essential for maintaining test isolation and preventing data conflicts between tests.
|
||||
|
||||
## 🚨 CRITICAL: Test Isolation Requirements
|
||||
|
||||
**ALL tests must use these factories instead of manual model creation** to ensure:
|
||||
- ✅ Unique IDs prevent primary key conflicts
|
||||
- ✅ Consistent test data structure
|
||||
- ✅ Isolated test execution
|
||||
- ✅ Deterministic test results
|
||||
|
||||
## Factory Pattern Overview
|
||||
|
||||
Each model has a corresponding factory that follows this pattern:
|
||||
|
||||
```python
|
||||
class ModelFactory:
|
||||
@staticmethod
|
||||
def build(**kwargs):
|
||||
"""Build model instance without saving to database."""
|
||||
defaults = {
|
||||
'id': generate_unique_id(),
|
||||
'field1': 'default_value',
|
||||
'field2': generate_unique_name('Prefix')
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return Model(**defaults)
|
||||
|
||||
@staticmethod
|
||||
def create(session, **kwargs):
|
||||
"""Create and save model instance to database."""
|
||||
instance = ModelFactory.build(**kwargs)
|
||||
session.add(instance)
|
||||
session.commit()
|
||||
session.refresh(instance)
|
||||
return instance
|
||||
```
|
||||
|
||||
## Available Factories
|
||||
|
||||
### TeamFactory (`team_factory.py`)
|
||||
|
||||
**Purpose**: Generate Team instances with unique IDs and team data.
|
||||
|
||||
**Basic Usage:**
|
||||
```python
|
||||
from tests.factories.team_factory import TeamFactory
|
||||
|
||||
# Build without saving
|
||||
team = TeamFactory.build(abbrev="LAD", lname="Los Angeles Dodgers")
|
||||
|
||||
# Create and save to database
|
||||
team = TeamFactory.create(db_session, abbrev="LAD", wallet=50000)
|
||||
```
|
||||
|
||||
**Specialized Methods:**
|
||||
```python
|
||||
# Create AI team
|
||||
ai_team = TeamFactory.build_ai_team(abbrev="AI1")
|
||||
|
||||
# Create human team
|
||||
human_team = TeamFactory.build_human_team(abbrev="HUM1")
|
||||
|
||||
# Create multiple teams with unique IDs
|
||||
teams = TeamFactory.build_multiple(3, season=9)
|
||||
```
|
||||
|
||||
**Default Values:**
|
||||
- `id`: Unique generated ID
|
||||
- `abbrev`: "TST"
|
||||
- `lname`: "Test Team"
|
||||
- `wallet`: 25000
|
||||
- `is_ai`: False
|
||||
- All other required fields have sensible defaults
|
||||
|
||||
### CardsetFactory (`cardset_factory.py`)
|
||||
|
||||
**Purpose**: Generate Cardset instances for testing card sets and game configurations.
|
||||
|
||||
**Basic Usage:**
|
||||
```python
|
||||
from tests.factories.cardset_factory import CardsetFactory
|
||||
|
||||
# Build without saving
|
||||
cardset = CardsetFactory.build(name="2024 Season", ranked_legal=True)
|
||||
|
||||
# Create and save to database
|
||||
cardset = CardsetFactory.create(db_session, name="Test Set")
|
||||
```
|
||||
|
||||
**Specialized Methods:**
|
||||
```python
|
||||
# Create ranked legal cardset
|
||||
ranked_set = CardsetFactory.create_ranked(db_session, name="Ranked Set")
|
||||
|
||||
# Create multiple cardsets
|
||||
cardsets = CardsetFactory.create_batch(db_session, 3, ranked_legal=True)
|
||||
```
|
||||
|
||||
**Default Values:**
|
||||
- `id`: Unique generated ID
|
||||
- `name`: "Test Cardset [unique]"
|
||||
- `ranked_legal`: False
|
||||
|
||||
### ManagerAiFactory (`manager_ai_factory.py`)
|
||||
|
||||
**Purpose**: Generate ManagerAi instances for testing AI decision-making.
|
||||
|
||||
**Basic Usage:**
|
||||
```python
|
||||
from tests.factories.manager_ai_factory import ManagerAiFactory
|
||||
|
||||
# Build AI with default settings
|
||||
ai = ManagerAiFactory.build_balanced()
|
||||
|
||||
# Create aggressive AI
|
||||
ai = ManagerAiFactory.create_aggressive(db_session)
|
||||
```
|
||||
|
||||
**Specialized Methods:**
|
||||
```python
|
||||
# Predefined AI types
|
||||
balanced_ai = ManagerAiFactory.build_balanced()
|
||||
aggressive_ai = ManagerAiFactory.build_aggressive()
|
||||
conservative_ai = ManagerAiFactory.build_conservative()
|
||||
|
||||
# Custom AI settings
|
||||
custom_ai = ManagerAiFactory.build(steal=10, running=8, hold=3)
|
||||
```
|
||||
|
||||
## Factory Usage Patterns
|
||||
|
||||
### ✅ Correct Usage Patterns
|
||||
|
||||
**Basic Model Creation:**
|
||||
```python
|
||||
def test_team_creation(db_session):
|
||||
team = TeamFactory.create(db_session, abbrev="BOS")
|
||||
assert team.abbrev == "BOS"
|
||||
assert team.id is not None
|
||||
```
|
||||
|
||||
**Custom Field Values:**
|
||||
```python
|
||||
def test_ai_team_behavior(db_session):
|
||||
ai_team = TeamFactory.create(
|
||||
db_session,
|
||||
is_ai=True,
|
||||
abbrev="AI1",
|
||||
wallet=100000
|
||||
)
|
||||
assert ai_team.is_ai is True
|
||||
```
|
||||
|
||||
**Multiple Related Objects:**
|
||||
```python
|
||||
def test_game_creation(db_session):
|
||||
home_team = TeamFactory.create(db_session, abbrev="HOME")
|
||||
away_team = TeamFactory.create(db_session, abbrev="AWAY")
|
||||
cardset = CardsetFactory.create(db_session, ranked_legal=True)
|
||||
|
||||
# Test game creation with related objects
|
||||
# Game logic here...
|
||||
```
|
||||
|
||||
**Batch Creation:**
|
||||
```python
|
||||
def test_multiple_teams(db_session):
|
||||
teams = TeamFactory.build_multiple(5)
|
||||
for team in teams:
|
||||
db_session.add(team)
|
||||
db_session.commit()
|
||||
|
||||
# All teams have unique IDs
|
||||
ids = [team.id for team in teams]
|
||||
assert len(set(ids)) == 5
|
||||
```
|
||||
|
||||
### ❌ Anti-Patterns to Avoid
|
||||
|
||||
**Manual Model Creation:**
|
||||
```python
|
||||
# DON'T DO THIS - hardcoded IDs cause conflicts
|
||||
def test_bad_pattern(db_session):
|
||||
team = Team(
|
||||
id=1, # ❌ Hardcoded ID
|
||||
abbrev="TST",
|
||||
lname="Test Team",
|
||||
# ... many required fields
|
||||
)
|
||||
```
|
||||
|
||||
**Shared Mutable State:**
|
||||
```python
|
||||
# DON'T DO THIS - shared state between tests
|
||||
SHARED_TEAM = Team(id=999, abbrev="SHARED")
|
||||
|
||||
def test_bad_shared_state(db_session):
|
||||
db_session.add(SHARED_TEAM) # ❌ Modifies shared state
|
||||
```
|
||||
|
||||
**Non-unique Values:**
|
||||
```python
|
||||
# DON'T DO THIS - non-unique values cause conflicts
|
||||
def test_bad_non_unique(db_session):
|
||||
team1 = Team(id=1, abbrev="SAME") # ❌ Same ID
|
||||
team2 = Team(id=1, abbrev="SAME") # ❌ Same ID
|
||||
```
|
||||
|
||||
## Creating New Factories
|
||||
|
||||
When adding new models, create corresponding factories following this template:
|
||||
|
||||
### 1. Create Factory File
|
||||
|
||||
Create `tests/factories/model_name_factory.py`:
|
||||
|
||||
```python
|
||||
"""
|
||||
ModelName factory for generating test data.
|
||||
|
||||
Provides methods to create unique, valid ModelName instances for testing
|
||||
without conflicts between test runs.
|
||||
"""
|
||||
|
||||
from app.models.model_name import ModelName
|
||||
from tests.conftest import generate_unique_id, generate_unique_name
|
||||
|
||||
|
||||
class ModelNameFactory:
|
||||
"""Factory for creating ModelName test instances."""
|
||||
|
||||
@staticmethod
|
||||
def build(**kwargs):
|
||||
"""
|
||||
Build a ModelName instance without saving to database.
|
||||
|
||||
Args:
|
||||
**kwargs: Override default field values
|
||||
|
||||
Returns:
|
||||
ModelName: Configured model instance
|
||||
|
||||
Example:
|
||||
model = ModelNameFactory.build(field="custom_value")
|
||||
"""
|
||||
defaults = {
|
||||
'id': generate_unique_id(),
|
||||
'name': generate_unique_name('ModelName'),
|
||||
# Add all required fields with sensible defaults
|
||||
}
|
||||
|
||||
# Override defaults with provided kwargs
|
||||
defaults.update(kwargs)
|
||||
return ModelName(**defaults)
|
||||
|
||||
@staticmethod
|
||||
def create(session, **kwargs):
|
||||
"""
|
||||
Create and save a ModelName instance to the database.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
**kwargs: Override default field values
|
||||
|
||||
Returns:
|
||||
ModelName: Saved model instance
|
||||
|
||||
Example:
|
||||
model = ModelNameFactory.create(session, field="custom_value")
|
||||
"""
|
||||
instance = ModelNameFactory.build(**kwargs)
|
||||
session.add(instance)
|
||||
session.commit()
|
||||
session.refresh(instance)
|
||||
return instance
|
||||
|
||||
# Add specialized factory methods as needed
|
||||
@staticmethod
|
||||
def build_special_type(**kwargs):
|
||||
"""Build specialized variant of model."""
|
||||
special_defaults = {
|
||||
'special_field': 'special_value'
|
||||
}
|
||||
special_defaults.update(kwargs)
|
||||
return ModelNameFactory.build(**special_defaults)
|
||||
```
|
||||
|
||||
### 2. Update Factory __init__.py
|
||||
|
||||
Add your factory to `tests/factories/__init__.py`:
|
||||
|
||||
```python
|
||||
from .model_name_factory import ModelNameFactory
|
||||
|
||||
__all__ = [
|
||||
# ... existing factories
|
||||
"ModelNameFactory",
|
||||
]
|
||||
```
|
||||
|
||||
### 3. Add Factory Tests
|
||||
|
||||
Create tests for your factory in `tests/unit/factories/test_model_name_factory.py`:
|
||||
|
||||
```python
|
||||
def test_build_creates_valid_instance():
|
||||
model = ModelNameFactory.build()
|
||||
assert model.id is not None
|
||||
assert model.name is not None
|
||||
|
||||
def test_create_saves_to_database(db_session):
|
||||
model = ModelNameFactory.create(db_session)
|
||||
retrieved = db_session.get(ModelName, model.id)
|
||||
assert retrieved is not None
|
||||
|
||||
def test_unique_ids_generated():
|
||||
models = [ModelNameFactory.build() for _ in range(5)]
|
||||
ids = [model.id for model in models]
|
||||
assert len(set(ids)) == 5 # All unique
|
||||
```
|
||||
|
||||
## Helper Functions
|
||||
|
||||
### `generate_unique_id()`
|
||||
Generates unique integer IDs using UUID hex conversion:
|
||||
```python
|
||||
id = generate_unique_id() # Returns: 3847291847
|
||||
```
|
||||
|
||||
### `generate_unique_name(prefix="Test")`
|
||||
Generates unique names with UUID suffix:
|
||||
```python
|
||||
name = generate_unique_name("Team") # Returns: "Team a3b4c5d6"
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Always use factories** for test data creation
|
||||
2. **Never hardcode IDs** - use `generate_unique_id()`
|
||||
3. **Provide sensible defaults** for all required fields
|
||||
4. **Override only what you need** in tests
|
||||
5. **Use specialized methods** for common patterns
|
||||
6. **Test your factories** to ensure they work correctly
|
||||
7. **Keep factories simple** - complex logic belongs in services
|
||||
|
||||
## Integration with Tests
|
||||
|
||||
**Required imports for new test files:**
|
||||
```python
|
||||
from tests.factories.team_factory import TeamFactory
|
||||
from tests.factories.cardset_factory import CardsetFactory
|
||||
# Import other factories as needed
|
||||
```
|
||||
|
||||
**Required fixture usage:**
|
||||
```python
|
||||
def test_something(db_session): # Use db_session from conftest.py
|
||||
team = TeamFactory.create(db_session, abbrev="TEST")
|
||||
# Test logic here
|
||||
```
|
||||
|
||||
This pattern ensures consistent, isolated, and reliable tests across the entire project.
|
||||
16
tests/factories/__init__.py
Normal file
16
tests/factories/__init__.py
Normal file
@ -0,0 +1,16 @@
|
||||
"""
|
||||
Test data factories for Paper Dynasty web app.
|
||||
|
||||
Factories generate unique, valid test data to ensure test isolation
|
||||
and prevent conflicts between test runs.
|
||||
"""
|
||||
|
||||
from .cardset_factory import CardsetFactory
|
||||
from .manager_ai_factory import ManagerAiFactory
|
||||
from .team_factory import TeamFactory
|
||||
|
||||
__all__ = [
|
||||
"CardsetFactory",
|
||||
"ManagerAiFactory",
|
||||
"TeamFactory",
|
||||
]
|
||||
116
tests/factories/cardset_factory.py
Normal file
116
tests/factories/cardset_factory.py
Normal file
@ -0,0 +1,116 @@
|
||||
"""
|
||||
Cardset factory for generating test data.
|
||||
|
||||
Provides methods to create unique, valid Cardset instances for testing
|
||||
without conflicts between test runs.
|
||||
"""
|
||||
|
||||
from app.models.cardset import Cardset
|
||||
from tests.conftest import generate_unique_id, generate_unique_name
|
||||
|
||||
|
||||
class CardsetFactory:
|
||||
"""Factory for creating Cardset test instances."""
|
||||
|
||||
@staticmethod
|
||||
def build(**kwargs):
|
||||
"""
|
||||
Build a Cardset instance without saving to database.
|
||||
|
||||
Args:
|
||||
**kwargs: Override default field values
|
||||
|
||||
Returns:
|
||||
Cardset: Configured cardset instance
|
||||
|
||||
Example:
|
||||
cardset = CardsetFactory.build(name="Custom Name")
|
||||
cardset = CardsetFactory.build(ranked_legal=True)
|
||||
"""
|
||||
defaults = {
|
||||
'id': generate_unique_id(),
|
||||
'name': generate_unique_name("Cardset"),
|
||||
'ranked_legal': False
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return Cardset(**defaults)
|
||||
|
||||
@classmethod
|
||||
def create(cls, session, **kwargs):
|
||||
"""
|
||||
Create and save a Cardset instance to database.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
**kwargs: Override default field values
|
||||
|
||||
Returns:
|
||||
Cardset: Saved cardset instance
|
||||
|
||||
Example:
|
||||
cardset = CardsetFactory.create(session, ranked_legal=True)
|
||||
"""
|
||||
cardset = cls.build(**kwargs)
|
||||
session.add(cardset)
|
||||
session.flush() # Get ID without committing
|
||||
return cardset
|
||||
|
||||
@classmethod
|
||||
def build_batch(cls, count, **kwargs):
|
||||
"""
|
||||
Build multiple Cardset instances.
|
||||
|
||||
Args:
|
||||
count: Number of cardsets to create
|
||||
**kwargs: Common field values for all cardsets
|
||||
|
||||
Returns:
|
||||
list[Cardset]: List of cardset instances
|
||||
|
||||
Example:
|
||||
cardsets = CardsetFactory.build_batch(3, ranked_legal=True)
|
||||
"""
|
||||
return [cls.build(**kwargs) for _ in range(count)]
|
||||
|
||||
@classmethod
|
||||
def create_batch(cls, session, count, **kwargs):
|
||||
"""
|
||||
Create and save multiple Cardset instances.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
count: Number of cardsets to create
|
||||
**kwargs: Common field values for all cardsets
|
||||
|
||||
Returns:
|
||||
list[Cardset]: List of saved cardset instances
|
||||
|
||||
Example:
|
||||
cardsets = CardsetFactory.create_batch(session, 3, ranked_legal=True)
|
||||
"""
|
||||
cardsets = cls.build_batch(count, **kwargs)
|
||||
session.add_all(cardsets)
|
||||
session.flush()
|
||||
return cardsets
|
||||
|
||||
# Preset factories for common scenarios
|
||||
@classmethod
|
||||
def build_ranked_legal(cls, **kwargs):
|
||||
"""Build a ranked legal cardset."""
|
||||
defaults = {'ranked_legal': True, 'name': generate_unique_name("Ranked Set")}
|
||||
defaults.update(kwargs)
|
||||
return cls.build(**defaults)
|
||||
|
||||
@classmethod
|
||||
def build_casual(cls, **kwargs):
|
||||
"""Build a casual (non-ranked) cardset."""
|
||||
defaults = {'ranked_legal': False, 'name': generate_unique_name("Casual Set")}
|
||||
defaults.update(kwargs)
|
||||
return cls.build(**defaults)
|
||||
|
||||
@classmethod
|
||||
def build_historic(cls, **kwargs):
|
||||
"""Build a historic cardset."""
|
||||
defaults = {'ranked_legal': False, 'name': generate_unique_name("Historic Set")}
|
||||
defaults.update(kwargs)
|
||||
return cls.build(**defaults)
|
||||
154
tests/factories/manager_ai_factory.py
Normal file
154
tests/factories/manager_ai_factory.py
Normal file
@ -0,0 +1,154 @@
|
||||
"""
|
||||
ManagerAi factory for generating test data.
|
||||
|
||||
Provides methods to create unique, valid ManagerAi instances for testing
|
||||
without conflicts between test runs.
|
||||
"""
|
||||
|
||||
from app.models.manager_ai import ManagerAi
|
||||
from tests.conftest import generate_unique_name
|
||||
|
||||
|
||||
class ManagerAiFactory:
|
||||
"""Factory for creating ManagerAi test instances."""
|
||||
|
||||
@staticmethod
|
||||
def build(**kwargs):
|
||||
"""
|
||||
Build a ManagerAi instance without saving to database.
|
||||
|
||||
Args:
|
||||
**kwargs: Override default field values
|
||||
|
||||
Returns:
|
||||
ManagerAi: Configured manager AI instance
|
||||
|
||||
Example:
|
||||
ai = ManagerAiFactory.build(name="Test AI")
|
||||
ai = ManagerAiFactory.build(steal=10, running=8)
|
||||
"""
|
||||
defaults = {
|
||||
'name': generate_unique_name("AI Manager"),
|
||||
'steal': 5,
|
||||
'running': 5,
|
||||
'hold': 5,
|
||||
'catcher_throw': 5,
|
||||
'uncapped_home': 5,
|
||||
'uncapped_third': 5,
|
||||
'uncapped_trail': 5,
|
||||
'bullpen_matchup': 5,
|
||||
'behind_aggression': 5,
|
||||
'ahead_aggression': 5,
|
||||
'decide_throw': 5
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return ManagerAi(**defaults)
|
||||
|
||||
@classmethod
|
||||
def create(cls, session, **kwargs):
|
||||
"""
|
||||
Create and save a ManagerAi instance to database.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
**kwargs: Override default field values
|
||||
|
||||
Returns:
|
||||
ManagerAi: Saved manager AI instance
|
||||
|
||||
Example:
|
||||
ai = ManagerAiFactory.create(session, steal=10)
|
||||
"""
|
||||
ai = cls.build(**kwargs)
|
||||
session.add(ai)
|
||||
session.flush() # Get ID without committing
|
||||
return ai
|
||||
|
||||
@classmethod
|
||||
def build_batch(cls, count, **kwargs):
|
||||
"""
|
||||
Build multiple ManagerAi instances.
|
||||
|
||||
Args:
|
||||
count: Number of AIs to create
|
||||
**kwargs: Common field values for all AIs
|
||||
|
||||
Returns:
|
||||
list[ManagerAi]: List of AI instances
|
||||
|
||||
Example:
|
||||
ais = ManagerAiFactory.build_batch(3, steal=8)
|
||||
"""
|
||||
return [cls.build(**kwargs) for _ in range(count)]
|
||||
|
||||
@classmethod
|
||||
def create_batch(cls, session, count, **kwargs):
|
||||
"""
|
||||
Create and save multiple ManagerAi instances.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
count: Number of AIs to create
|
||||
**kwargs: Common field values for all AIs
|
||||
|
||||
Returns:
|
||||
list[ManagerAi]: List of saved AI instances
|
||||
|
||||
Example:
|
||||
ais = ManagerAiFactory.create_batch(session, 3, steal=8)
|
||||
"""
|
||||
ais = cls.build_batch(count, **kwargs)
|
||||
session.add_all(ais)
|
||||
session.flush()
|
||||
return ais
|
||||
|
||||
# Preset factories for common AI configurations
|
||||
@classmethod
|
||||
def build_balanced(cls, **kwargs):
|
||||
"""Build a balanced AI (all 5s)."""
|
||||
defaults = {
|
||||
'name': generate_unique_name("Balanced AI"),
|
||||
# All defaults are already 5
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return cls.build(**defaults)
|
||||
|
||||
@classmethod
|
||||
def build_aggressive(cls, **kwargs):
|
||||
"""Build an aggressive AI (YOLO preset)."""
|
||||
defaults = {
|
||||
'name': generate_unique_name("Aggressive AI"),
|
||||
'steal': 10,
|
||||
'running': 10,
|
||||
'hold': 5,
|
||||
'catcher_throw': 10,
|
||||
'uncapped_home': 10,
|
||||
'uncapped_third': 10,
|
||||
'uncapped_trail': 10,
|
||||
'bullpen_matchup': 3,
|
||||
'behind_aggression': 10,
|
||||
'ahead_aggression': 10,
|
||||
'decide_throw': 10
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return cls.build(**defaults)
|
||||
|
||||
@classmethod
|
||||
def build_conservative(cls, **kwargs):
|
||||
"""Build a conservative AI (Safe preset)."""
|
||||
defaults = {
|
||||
'name': generate_unique_name("Conservative AI"),
|
||||
'steal': 3,
|
||||
'running': 3,
|
||||
'hold': 8,
|
||||
'catcher_throw': 5,
|
||||
'uncapped_home': 5,
|
||||
'uncapped_third': 3,
|
||||
'uncapped_trail': 5,
|
||||
'bullpen_matchup': 8,
|
||||
'behind_aggression': 5,
|
||||
'ahead_aggression': 1,
|
||||
'decide_throw': 1
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return cls.build(**defaults)
|
||||
138
tests/factories/team_factory.py
Normal file
138
tests/factories/team_factory.py
Normal file
@ -0,0 +1,138 @@
|
||||
"""
|
||||
Team factory for generating test data.
|
||||
|
||||
Provides methods to create unique, valid Team instances for testing
|
||||
without conflicts between test runs.
|
||||
"""
|
||||
|
||||
from app.models.team import Team
|
||||
from tests.conftest import generate_unique_id, generate_unique_name
|
||||
|
||||
|
||||
class TeamFactory:
|
||||
"""Factory for creating Team test instances."""
|
||||
|
||||
@staticmethod
|
||||
def build(**kwargs):
|
||||
"""
|
||||
Build a Team instance without saving to database.
|
||||
|
||||
Args:
|
||||
**kwargs: Override default field values
|
||||
|
||||
Returns:
|
||||
Team: Configured team instance
|
||||
|
||||
Example:
|
||||
team = TeamFactory.build(abbrev="LAD")
|
||||
team = TeamFactory.build(is_ai=True, wallet=50000)
|
||||
"""
|
||||
defaults = {
|
||||
'id': generate_unique_id(),
|
||||
'abbrev': 'TST',
|
||||
'sname': 'Test',
|
||||
'lname': 'Test Team',
|
||||
'gmid': generate_unique_id(),
|
||||
'gmname': 'Test GM',
|
||||
'gsheet': 'test-sheet-url',
|
||||
'wallet': 25000,
|
||||
'team_value': 100000,
|
||||
'collection_value': 75000,
|
||||
'logo': 'https://example.com/test-logo.png',
|
||||
'color': 'ff0000',
|
||||
'season': 9,
|
||||
'career': 1,
|
||||
'ranking': 50,
|
||||
'has_guide': False,
|
||||
'is_ai': False,
|
||||
}
|
||||
|
||||
# Override defaults with provided kwargs
|
||||
defaults.update(kwargs)
|
||||
return Team(**defaults)
|
||||
|
||||
@staticmethod
|
||||
def create(session, **kwargs):
|
||||
"""
|
||||
Create and save a Team instance to the database.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
**kwargs: Override default field values
|
||||
|
||||
Returns:
|
||||
Team: Saved team instance
|
||||
|
||||
Example:
|
||||
team = TeamFactory.create(session, abbrev="LAD")
|
||||
"""
|
||||
team = TeamFactory.build(**kwargs)
|
||||
session.add(team)
|
||||
session.commit()
|
||||
session.refresh(team)
|
||||
return team
|
||||
|
||||
@staticmethod
|
||||
def build_ai_team(**kwargs):
|
||||
"""
|
||||
Build an AI team with appropriate defaults.
|
||||
|
||||
Args:
|
||||
**kwargs: Override default field values
|
||||
|
||||
Returns:
|
||||
Team: AI team instance
|
||||
"""
|
||||
ai_defaults = {
|
||||
'is_ai': True,
|
||||
'abbrev': 'AI',
|
||||
'lname': 'AI Team',
|
||||
'gmname': 'AI Manager',
|
||||
}
|
||||
ai_defaults.update(kwargs)
|
||||
return TeamFactory.build(**ai_defaults)
|
||||
|
||||
@staticmethod
|
||||
def build_human_team(**kwargs):
|
||||
"""
|
||||
Build a human team with appropriate defaults.
|
||||
|
||||
Args:
|
||||
**kwargs: Override default field values
|
||||
|
||||
Returns:
|
||||
Team: Human team instance
|
||||
"""
|
||||
human_defaults = {
|
||||
'is_ai': False,
|
||||
'abbrev': 'HUM',
|
||||
'lname': 'Human Team',
|
||||
'gmname': 'Human Manager',
|
||||
'wallet': 50000,
|
||||
}
|
||||
human_defaults.update(kwargs)
|
||||
return TeamFactory.build(**human_defaults)
|
||||
|
||||
@staticmethod
|
||||
def build_multiple(count=3, **kwargs):
|
||||
"""
|
||||
Build multiple Team instances with unique IDs.
|
||||
|
||||
Args:
|
||||
count: Number of teams to create
|
||||
**kwargs: Base field values for all teams
|
||||
|
||||
Returns:
|
||||
list[Team]: List of team instances
|
||||
"""
|
||||
teams = []
|
||||
for i in range(count):
|
||||
team_kwargs = kwargs.copy()
|
||||
team_kwargs['id'] = generate_unique_id()
|
||||
team_kwargs['gmid'] = generate_unique_id()
|
||||
if 'abbrev' not in team_kwargs:
|
||||
team_kwargs['abbrev'] = f'T{i+1:02d}'
|
||||
if 'lname' not in team_kwargs:
|
||||
team_kwargs['lname'] = f'Team {i+1}'
|
||||
teams.append(TeamFactory.build(**team_kwargs))
|
||||
return teams
|
||||
243
tests/unit/models/test_cardset.py
Normal file
243
tests/unit/models/test_cardset.py
Normal file
@ -0,0 +1,243 @@
|
||||
"""
|
||||
Unit tests for Cardset model.
|
||||
|
||||
Tests data validation, field constraints, and model behavior.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from sqlmodel import Session, SQLModel, create_engine, text
|
||||
|
||||
from app.models.cardset import Cardset, CardsetBase
|
||||
from tests.factories.cardset_factory import CardsetFactory
|
||||
|
||||
|
||||
# Using centralized fixtures from conftest.py for proper test isolation
|
||||
|
||||
|
||||
class TestCardsetBase:
|
||||
"""Test CardsetBase model validation."""
|
||||
|
||||
def test_create_with_defaults(self):
|
||||
"""Test creating Cardset with default values."""
|
||||
cardset = CardsetBase(name="2024 Season")
|
||||
|
||||
assert cardset.name == "2024 Season"
|
||||
assert cardset.ranked_legal is False # Default value
|
||||
|
||||
def test_create_with_custom_values(self):
|
||||
"""Test creating Cardset with custom values."""
|
||||
cardset = CardsetBase(
|
||||
name="2023 Season",
|
||||
ranked_legal=True
|
||||
)
|
||||
|
||||
assert cardset.name == "2023 Season"
|
||||
assert cardset.ranked_legal is True
|
||||
|
||||
def test_create_with_id(self):
|
||||
"""Test creating Cardset with explicit ID."""
|
||||
cardset = CardsetBase(
|
||||
id=100,
|
||||
name="Historic Set",
|
||||
ranked_legal=False
|
||||
)
|
||||
|
||||
assert cardset.id == 100
|
||||
assert cardset.name == "Historic Set"
|
||||
assert cardset.ranked_legal is False
|
||||
|
||||
def test_required_name_field(self):
|
||||
"""Test that name field is required."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
CardsetBase()
|
||||
assert "Field required" in str(exc_info.value)
|
||||
|
||||
def test_field_descriptions(self):
|
||||
"""Test that field descriptions are properly set."""
|
||||
# Access field descriptions through the model class using Pydantic v2
|
||||
fields = CardsetBase.model_fields
|
||||
assert "Name of the card set" in str(fields['name'])
|
||||
assert "Whether this cardset is legal for ranked play" in str(fields['ranked_legal'])
|
||||
|
||||
|
||||
class TestCardset:
|
||||
"""Test Cardset table model."""
|
||||
|
||||
def test_create_and_save(self, db_session):
|
||||
"""Test creating and saving Cardset to database."""
|
||||
cardset = CardsetFactory.create(
|
||||
db_session,
|
||||
name="2024 Season",
|
||||
ranked_legal=True
|
||||
)
|
||||
|
||||
assert cardset.id is not None
|
||||
assert cardset.name == "2024 Season"
|
||||
assert cardset.ranked_legal is True
|
||||
|
||||
def test_retrieve_from_database(self, db_session):
|
||||
"""Test retrieving Cardset from database."""
|
||||
# Create and save
|
||||
cardset = CardsetFactory.create(
|
||||
db_session,
|
||||
name="Test Retrieval Set",
|
||||
ranked_legal=False
|
||||
)
|
||||
|
||||
# Retrieve
|
||||
retrieved = db_session.get(Cardset, cardset.id)
|
||||
assert retrieved is not None
|
||||
assert retrieved.name == "Test Retrieval Set"
|
||||
assert retrieved.ranked_legal is False
|
||||
|
||||
def test_update_values(self, db_session):
|
||||
"""Test updating Cardset values."""
|
||||
cardset = Cardset(id=3, name="Update Test")
|
||||
db_session.add(cardset)
|
||||
db_session.commit()
|
||||
|
||||
# Update values
|
||||
cardset.ranked_legal = True
|
||||
cardset.name = "Updated Name"
|
||||
db_session.commit()
|
||||
|
||||
# Verify updates
|
||||
db_session.refresh(cardset)
|
||||
assert cardset.name == "Updated Name"
|
||||
assert cardset.ranked_legal is True
|
||||
|
||||
def test_multiple_instances(self, db_session):
|
||||
"""Test creating multiple Cardset instances."""
|
||||
cardset1 = Cardset(id=10, name="Set A", ranked_legal=True)
|
||||
cardset2 = Cardset(id=11, name="Set B", ranked_legal=False)
|
||||
cardset3 = Cardset(id=12, name="Set C", ranked_legal=True)
|
||||
|
||||
db_session.add_all([cardset1, cardset2, cardset3])
|
||||
db_session.commit()
|
||||
|
||||
# Verify all saved with correct values
|
||||
assert cardset1.name == "Set A"
|
||||
assert cardset1.ranked_legal is True
|
||||
assert cardset2.name == "Set B"
|
||||
assert cardset2.ranked_legal is False
|
||||
assert cardset3.name == "Set C"
|
||||
assert cardset3.ranked_legal is True
|
||||
|
||||
def test_name_is_indexed(self):
|
||||
"""Test that name field has index configuration."""
|
||||
fields = Cardset.model_fields
|
||||
name_field = fields['name']
|
||||
# Check if field has index configuration
|
||||
assert hasattr(name_field, 'json_schema_extra') or 'index' in str(name_field)
|
||||
|
||||
|
||||
class TestCardsetBusinessScenarios:
|
||||
"""Test real-world Cardset usage scenarios."""
|
||||
|
||||
def test_ranked_legal_cardsets(self, db_session):
|
||||
"""Test filtering for ranked legal cardsets."""
|
||||
# Create multiple cardsets using factory
|
||||
ranked_set = CardsetFactory.create(db_session, name="2024 Ranked", ranked_legal=True)
|
||||
casual_set = CardsetFactory.create(db_session, name="2024 Casual", ranked_legal=False)
|
||||
historic_set = CardsetFactory.create(db_session, name="Historic Collection", ranked_legal=False)
|
||||
|
||||
# Query for ranked legal sets (would be done in service layer)
|
||||
from sqlmodel import select
|
||||
ranked_cardsets = db_session.exec(
|
||||
select(Cardset).where(Cardset.ranked_legal == True)
|
||||
).all()
|
||||
|
||||
assert len(ranked_cardsets) == 1
|
||||
assert ranked_cardsets[0].name == "2024 Ranked"
|
||||
|
||||
def test_cardset_naming_conventions(self, db_session):
|
||||
"""Test various cardset naming scenarios."""
|
||||
cardsets = [
|
||||
Cardset(id=30, name="2024 Season", ranked_legal=True),
|
||||
Cardset(id=31, name="2023 Season", ranked_legal=False),
|
||||
Cardset(id=32, name="Historic Collection", ranked_legal=False),
|
||||
Cardset(id=33, name="Special Event - All-Stars", ranked_legal=True),
|
||||
Cardset(id=34, name="Beta Test Set", ranked_legal=False),
|
||||
]
|
||||
|
||||
db_session.add_all(cardsets)
|
||||
db_session.commit()
|
||||
|
||||
# Verify all names are preserved correctly
|
||||
for cardset in cardsets:
|
||||
db_session.refresh(cardset)
|
||||
# Names should be preserved exactly as entered
|
||||
assert len(cardset.name) > 0
|
||||
assert cardset.name in [
|
||||
"2024 Season", "2023 Season", "Historic Collection",
|
||||
"Special Event - All-Stars", "Beta Test Set"
|
||||
]
|
||||
|
||||
def test_default_ranked_legal_behavior(self, db_session):
|
||||
"""Test that cardsets default to not ranked legal."""
|
||||
cardset = Cardset(id=40, name="Default Test")
|
||||
db_session.add(cardset)
|
||||
db_session.commit()
|
||||
db_session.refresh(cardset)
|
||||
|
||||
# Should default to False
|
||||
assert cardset.ranked_legal is False
|
||||
|
||||
def test_explicit_id_assignment(self, db_session):
|
||||
"""Test that IDs can be explicitly assigned (not auto-increment)."""
|
||||
# Based on Discord app model, ID is not auto-increment
|
||||
cardset1 = Cardset(id=1000, name="High ID Set")
|
||||
cardset2 = Cardset(id=2000, name="Another High ID Set")
|
||||
|
||||
db_session.add_all([cardset1, cardset2])
|
||||
db_session.commit()
|
||||
|
||||
assert cardset1.id == 1000
|
||||
assert cardset2.id == 2000
|
||||
|
||||
def test_unique_id_constraint(self, db_session):
|
||||
"""Test that duplicate IDs are not allowed."""
|
||||
cardset1 = Cardset(id=500, name="First Set")
|
||||
cardset2 = Cardset(id=500, name="Duplicate ID Set")
|
||||
|
||||
db_session.add(cardset1)
|
||||
db_session.commit()
|
||||
|
||||
# Adding second cardset with same ID should fail
|
||||
db_session.add(cardset2)
|
||||
with pytest.raises(Exception): # SQLAlchemy will raise an IntegrityError
|
||||
db_session.commit()
|
||||
|
||||
|
||||
class TestCardsetDataIntegrity:
|
||||
"""Test data integrity and validation."""
|
||||
|
||||
def test_empty_name_not_allowed(self):
|
||||
"""Test that empty name is not allowed."""
|
||||
with pytest.raises(ValidationError):
|
||||
CardsetBase(name="")
|
||||
|
||||
def test_none_name_not_allowed(self):
|
||||
"""Test that None name is not allowed."""
|
||||
with pytest.raises(ValidationError):
|
||||
CardsetBase(name=None)
|
||||
|
||||
def test_boolean_validation_for_ranked_legal(self):
|
||||
"""Test that ranked_legal field only accepts boolean values."""
|
||||
# Valid boolean values
|
||||
cardset_true = CardsetBase(name="Test", ranked_legal=True)
|
||||
cardset_false = CardsetBase(name="Test", ranked_legal=False)
|
||||
|
||||
assert cardset_true.ranked_legal is True
|
||||
assert cardset_false.ranked_legal is False
|
||||
|
||||
# Invalid values should be coerced or raise validation error
|
||||
with pytest.raises(ValidationError):
|
||||
CardsetBase(name="Test", ranked_legal="invalid")
|
||||
|
||||
def test_id_field_accepts_none(self):
|
||||
"""Test that ID field can be None (for cases where ID isn't known yet)."""
|
||||
cardset = CardsetBase(name="No ID Set", id=None)
|
||||
assert cardset.id is None
|
||||
assert cardset.name == "No ID Set"
|
||||
276
tests/unit/models/test_cardset_proper.py
Normal file
276
tests/unit/models/test_cardset_proper.py
Normal file
@ -0,0 +1,276 @@
|
||||
"""
|
||||
Unit tests for Cardset model using proper testing patterns.
|
||||
|
||||
Tests data validation, field constraints, and model behavior with
|
||||
transaction rollback for test isolation and factories for unique data.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from sqlmodel import select
|
||||
|
||||
from app.models.cardset import Cardset, CardsetBase
|
||||
from tests.factories import CardsetFactory
|
||||
|
||||
|
||||
class TestCardsetBaseValidation:
|
||||
"""Test CardsetBase model validation (no database needed)."""
|
||||
|
||||
def test_create_with_defaults(self):
|
||||
"""Test creating Cardset with default values."""
|
||||
cardset = CardsetFactory.build()
|
||||
|
||||
assert cardset.name.startswith("Cardset")
|
||||
assert cardset.ranked_legal is False
|
||||
assert cardset.id is not None
|
||||
|
||||
def test_create_with_custom_values(self):
|
||||
"""Test creating Cardset with custom values."""
|
||||
cardset = CardsetFactory.build(
|
||||
name="Custom Season",
|
||||
ranked_legal=True
|
||||
)
|
||||
|
||||
assert cardset.name == "Custom Season"
|
||||
assert cardset.ranked_legal is True
|
||||
|
||||
def test_create_with_explicit_id(self):
|
||||
"""Test creating Cardset with explicit ID."""
|
||||
cardset = CardsetFactory.build(
|
||||
id=12345,
|
||||
name="Specific ID Set"
|
||||
)
|
||||
|
||||
assert cardset.id == 12345
|
||||
assert cardset.name == "Specific ID Set"
|
||||
|
||||
def test_required_name_field(self):
|
||||
"""Test that name field is required."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
CardsetBase()
|
||||
assert "Field required" in str(exc_info.value)
|
||||
|
||||
def test_empty_name_validation(self):
|
||||
"""Test that empty name is not allowed."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
CardsetBase(name="")
|
||||
assert "Name cannot be empty" in str(exc_info.value)
|
||||
|
||||
def test_whitespace_only_name_validation(self):
|
||||
"""Test that whitespace-only name is not allowed."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
CardsetBase(name=" ")
|
||||
assert "Name cannot be empty" in str(exc_info.value)
|
||||
|
||||
def test_none_name_not_allowed(self):
|
||||
"""Test that None name is not allowed."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
CardsetBase(name=None)
|
||||
assert "Input should be a valid string" in str(exc_info.value)
|
||||
|
||||
def test_boolean_validation_for_ranked_legal(self):
|
||||
"""Test that ranked_legal field validates boolean values."""
|
||||
# Valid boolean values
|
||||
cardset_true = CardsetFactory.build(ranked_legal=True)
|
||||
cardset_false = CardsetFactory.build(ranked_legal=False)
|
||||
|
||||
assert cardset_true.ranked_legal is True
|
||||
assert cardset_false.ranked_legal is False
|
||||
|
||||
def test_field_descriptions(self):
|
||||
"""Test that field descriptions are properly set."""
|
||||
fields = CardsetBase.model_fields
|
||||
assert "Name of the card set" in str(fields['name'])
|
||||
assert "Whether this cardset is legal for ranked play" in str(fields['ranked_legal'])
|
||||
|
||||
def test_id_field_accepts_none(self):
|
||||
"""Test that ID field can be None."""
|
||||
cardset = CardsetFactory.build(id=None)
|
||||
assert cardset.id is None
|
||||
|
||||
|
||||
class TestCardsetDatabaseOperations:
|
||||
"""Test Cardset database operations with transaction rollback."""
|
||||
|
||||
def test_create_and_save(self, db_session):
|
||||
"""Test creating and saving Cardset to database."""
|
||||
cardset = CardsetFactory.create(
|
||||
db_session,
|
||||
name="Database Test Set",
|
||||
ranked_legal=True
|
||||
)
|
||||
|
||||
assert cardset.id is not None
|
||||
assert cardset.name == "Database Test Set"
|
||||
assert cardset.ranked_legal is True
|
||||
|
||||
def test_retrieve_from_database(self, db_session):
|
||||
"""Test retrieving Cardset from database."""
|
||||
# Create and save
|
||||
original = CardsetFactory.create(
|
||||
db_session,
|
||||
name="Retrieval Test",
|
||||
ranked_legal=False
|
||||
)
|
||||
|
||||
# Retrieve by ID
|
||||
retrieved = db_session.get(Cardset, original.id)
|
||||
assert retrieved is not None
|
||||
assert retrieved.name == "Retrieval Test"
|
||||
assert retrieved.ranked_legal is False
|
||||
assert retrieved.id == original.id
|
||||
|
||||
def test_update_values(self, db_session):
|
||||
"""Test updating Cardset values."""
|
||||
cardset = CardsetFactory.create(db_session, name="Update Test")
|
||||
|
||||
# Update values
|
||||
cardset.ranked_legal = True
|
||||
cardset.name = "Updated Name"
|
||||
db_session.flush()
|
||||
|
||||
# Verify updates (no commit needed due to rollback)
|
||||
assert cardset.name == "Updated Name"
|
||||
assert cardset.ranked_legal is True
|
||||
|
||||
def test_multiple_instances(self, db_session):
|
||||
"""Test creating multiple Cardset instances."""
|
||||
cardsets = CardsetFactory.create_batch(
|
||||
db_session, 3,
|
||||
ranked_legal=True
|
||||
)
|
||||
|
||||
# Verify all saved with unique IDs
|
||||
ids = [c.id for c in cardsets]
|
||||
assert len(set(ids)) == 3 # All IDs are unique
|
||||
assert all(c.ranked_legal is True for c in cardsets)
|
||||
|
||||
def test_query_operations(self, db_session):
|
||||
"""Test querying Cardset records."""
|
||||
# Create test data
|
||||
ranked_sets = CardsetFactory.create_batch(
|
||||
db_session, 2,
|
||||
ranked_legal=True
|
||||
)
|
||||
casual_sets = CardsetFactory.create_batch(
|
||||
db_session, 3,
|
||||
ranked_legal=False
|
||||
)
|
||||
|
||||
# Query for ranked legal sets
|
||||
ranked_results = db_session.exec(
|
||||
select(Cardset).where(Cardset.ranked_legal == True)
|
||||
).all()
|
||||
|
||||
# Query for casual sets
|
||||
casual_results = db_session.exec(
|
||||
select(Cardset).where(Cardset.ranked_legal == False)
|
||||
).all()
|
||||
|
||||
assert len(ranked_results) == 2
|
||||
assert len(casual_results) == 3
|
||||
assert all(c.ranked_legal is True for c in ranked_results)
|
||||
assert all(c.ranked_legal is False for c in casual_results)
|
||||
|
||||
|
||||
class TestCardsetFactoryPresets:
|
||||
"""Test CardsetFactory preset methods."""
|
||||
|
||||
def test_ranked_legal_preset(self, db_session):
|
||||
"""Test ranked legal cardset factory."""
|
||||
cardset = CardsetFactory.build_ranked_legal()
|
||||
|
||||
assert cardset.ranked_legal is True
|
||||
assert "Ranked Set" in cardset.name
|
||||
|
||||
def test_casual_preset(self, db_session):
|
||||
"""Test casual cardset factory."""
|
||||
cardset = CardsetFactory.build_casual()
|
||||
|
||||
assert cardset.ranked_legal is False
|
||||
assert "Casual Set" in cardset.name
|
||||
|
||||
def test_historic_preset(self, db_session):
|
||||
"""Test historic cardset factory."""
|
||||
cardset = CardsetFactory.build_historic()
|
||||
|
||||
assert cardset.ranked_legal is False
|
||||
assert "Historic Set" in cardset.name
|
||||
|
||||
def test_custom_preset_override(self, db_session):
|
||||
"""Test that preset defaults can be overridden."""
|
||||
cardset = CardsetFactory.build_ranked_legal(
|
||||
name="Custom Ranked Set",
|
||||
ranked_legal=False # Override the preset
|
||||
)
|
||||
|
||||
assert cardset.name == "Custom Ranked Set"
|
||||
assert cardset.ranked_legal is False
|
||||
|
||||
|
||||
class TestCardsetBusinessScenarios:
|
||||
"""Test real-world Cardset usage scenarios."""
|
||||
|
||||
def test_unique_naming_across_tests(self, db_session):
|
||||
"""Test that each test gets unique cardset names."""
|
||||
cardset1 = CardsetFactory.create(db_session)
|
||||
cardset2 = CardsetFactory.create(db_session)
|
||||
|
||||
# Names should be different due to unique generation
|
||||
assert cardset1.name != cardset2.name
|
||||
assert cardset1.id != cardset2.id
|
||||
|
||||
def test_batch_creation_uniqueness(self, db_session):
|
||||
"""Test that batch creation produces unique items."""
|
||||
cardsets = CardsetFactory.create_batch(db_session, 5)
|
||||
|
||||
# All should have unique IDs and names
|
||||
ids = [c.id for c in cardsets]
|
||||
names = [c.name for c in cardsets]
|
||||
|
||||
assert len(set(ids)) == 5 # All unique IDs
|
||||
assert len(set(names)) == 5 # All unique names
|
||||
|
||||
def test_filtering_by_ranked_status(self, db_session):
|
||||
"""Test filtering cardsets by ranked legal status."""
|
||||
# Create mixed data
|
||||
CardsetFactory.create_batch(db_session, 2, ranked_legal=True)
|
||||
CardsetFactory.create_batch(db_session, 3, ranked_legal=False)
|
||||
|
||||
# Test filtering
|
||||
all_cardsets = db_session.exec(select(Cardset)).all()
|
||||
ranked_cardsets = db_session.exec(
|
||||
select(Cardset).where(Cardset.ranked_legal == True)
|
||||
).all()
|
||||
casual_cardsets = db_session.exec(
|
||||
select(Cardset).where(Cardset.ranked_legal == False)
|
||||
).all()
|
||||
|
||||
assert len(all_cardsets) == 5
|
||||
assert len(ranked_cardsets) == 2
|
||||
assert len(casual_cardsets) == 3
|
||||
|
||||
def test_name_search_functionality(self, db_session):
|
||||
"""Test searching cardsets by name patterns."""
|
||||
# Create cardsets with specific names
|
||||
season_sets = CardsetFactory.create_batch(
|
||||
db_session, 2,
|
||||
name="2024 Season"
|
||||
)
|
||||
historic_sets = CardsetFactory.create_batch(
|
||||
db_session, 2,
|
||||
name="Historic Collection"
|
||||
)
|
||||
|
||||
# Search by name pattern
|
||||
season_results = db_session.exec(
|
||||
select(Cardset).where(Cardset.name.contains("Season"))
|
||||
).all()
|
||||
historic_results = db_session.exec(
|
||||
select(Cardset).where(Cardset.name.contains("Historic"))
|
||||
).all()
|
||||
|
||||
assert len(season_results) == 2
|
||||
assert len(historic_results) == 2
|
||||
assert all("Season" in c.name for c in season_results)
|
||||
assert all("Historic" in c.name for c in historic_results)
|
||||
244
tests/unit/models/test_manager_ai.py
Normal file
244
tests/unit/models/test_manager_ai.py
Normal file
@ -0,0 +1,244 @@
|
||||
"""
|
||||
Unit tests for ManagerAi model.
|
||||
|
||||
Tests data validation, field constraints, and model behavior.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
|
||||
from app.models.manager_ai import ManagerAi, ManagerAiBase
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_db():
|
||||
"""Create PostgreSQL test database connection."""
|
||||
# Use test database on port 5434
|
||||
test_url = "postgresql://paper_dynasty_user:paper_dynasty_test_password@localhost:5434/paper_dynasty_test"
|
||||
engine = create_engine(test_url, echo=False)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session(test_db):
|
||||
"""Create database session for testing."""
|
||||
with Session(test_db) as session:
|
||||
yield session
|
||||
# Clean up after each test
|
||||
session.rollback()
|
||||
|
||||
|
||||
class TestManagerAiBase:
|
||||
"""Test ManagerAiBase model validation."""
|
||||
|
||||
def test_create_with_defaults(self):
|
||||
"""Test creating ManagerAi with default values."""
|
||||
ai = ManagerAiBase(name="Test AI")
|
||||
|
||||
assert ai.name == "Test AI"
|
||||
assert ai.steal == 5
|
||||
assert ai.running == 5
|
||||
assert ai.hold == 5
|
||||
assert ai.catcher_throw == 5
|
||||
assert ai.uncapped_home == 5
|
||||
assert ai.uncapped_third == 5
|
||||
assert ai.uncapped_trail == 5
|
||||
assert ai.bullpen_matchup == 5
|
||||
assert ai.behind_aggression == 5
|
||||
assert ai.ahead_aggression == 5
|
||||
assert ai.decide_throw == 5
|
||||
|
||||
def test_create_with_custom_values(self):
|
||||
"""Test creating ManagerAi with custom values."""
|
||||
ai = ManagerAiBase(
|
||||
name="Aggressive AI",
|
||||
steal=10,
|
||||
running=8,
|
||||
hold=3,
|
||||
behind_aggression=9,
|
||||
ahead_aggression=2
|
||||
)
|
||||
|
||||
assert ai.name == "Aggressive AI"
|
||||
assert ai.steal == 10
|
||||
assert ai.running == 8
|
||||
assert ai.hold == 3
|
||||
assert ai.behind_aggression == 9
|
||||
assert ai.ahead_aggression == 2
|
||||
|
||||
def test_validate_field_ranges(self):
|
||||
"""Test field validation constraints."""
|
||||
# Valid values at boundaries
|
||||
ai = ManagerAiBase(
|
||||
name="Boundary Test",
|
||||
steal=1,
|
||||
running=10,
|
||||
hold=1
|
||||
)
|
||||
assert ai.steal == 1
|
||||
assert ai.running == 10
|
||||
assert ai.hold == 1
|
||||
|
||||
def test_invalid_field_values(self):
|
||||
"""Test that invalid field values raise ValidationError."""
|
||||
# Values below minimum
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ManagerAiBase(name="Invalid", steal=0)
|
||||
assert "Input should be greater than or equal to 1" in str(exc_info.value)
|
||||
|
||||
# Values above maximum
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ManagerAiBase(name="Invalid", steal=11)
|
||||
assert "Input should be less than or equal to 10" in str(exc_info.value)
|
||||
|
||||
def test_required_name_field(self):
|
||||
"""Test that name field is required."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ManagerAiBase()
|
||||
assert "Field required" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestManagerAi:
|
||||
"""Test ManagerAi table model."""
|
||||
|
||||
def test_create_and_save(self, session):
|
||||
"""Test creating and saving ManagerAi to database."""
|
||||
ai = ManagerAi(
|
||||
name="Test AI",
|
||||
steal=7,
|
||||
running=6,
|
||||
hold=4
|
||||
)
|
||||
|
||||
session.add(ai)
|
||||
session.commit()
|
||||
session.refresh(ai)
|
||||
|
||||
assert ai.id is not None
|
||||
assert ai.name == "Test AI"
|
||||
assert ai.steal == 7
|
||||
|
||||
def test_retrieve_from_database(self, session):
|
||||
"""Test retrieving ManagerAi from database."""
|
||||
# Create and save
|
||||
ai = ManagerAi(name="Retrieval Test", steal=8)
|
||||
session.add(ai)
|
||||
session.commit()
|
||||
|
||||
# Retrieve
|
||||
retrieved = session.get(ManagerAi, ai.id)
|
||||
assert retrieved is not None
|
||||
assert retrieved.name == "Retrieval Test"
|
||||
assert retrieved.steal == 8
|
||||
|
||||
def test_update_values(self, session):
|
||||
"""Test updating ManagerAi values."""
|
||||
ai = ManagerAi(name="Update Test")
|
||||
session.add(ai)
|
||||
session.commit()
|
||||
|
||||
# Update values
|
||||
ai.steal = 9
|
||||
ai.running = 7
|
||||
session.commit()
|
||||
|
||||
# Verify updates
|
||||
session.refresh(ai)
|
||||
assert ai.steal == 9
|
||||
assert ai.running == 7
|
||||
|
||||
def test_multiple_instances(self, session):
|
||||
"""Test creating multiple ManagerAi instances."""
|
||||
ai1 = ManagerAi(name="Balanced", steal=5, running=5)
|
||||
ai2 = ManagerAi(name="Aggressive", steal=10, running=10)
|
||||
ai3 = ManagerAi(name="Conservative", steal=1, running=1)
|
||||
|
||||
session.add_all([ai1, ai2, ai3])
|
||||
session.commit()
|
||||
|
||||
# Verify all saved with different IDs
|
||||
assert ai1.id != ai2.id != ai3.id
|
||||
assert ai1.name == "Balanced"
|
||||
assert ai2.steal == 10
|
||||
assert ai3.running == 1
|
||||
|
||||
def test_field_descriptions(self):
|
||||
"""Test that field descriptions are properly set."""
|
||||
ai = ManagerAi(name="Description Test")
|
||||
|
||||
# Access field descriptions through the model class using Pydantic v2
|
||||
fields = ManagerAi.model_fields
|
||||
assert "AI steal aggression level" in str(fields['steal'])
|
||||
assert "AI base running aggression" in str(fields['running'])
|
||||
assert "AI pitcher hold tendency" in str(fields['hold'])
|
||||
|
||||
|
||||
class TestManagerAiPresets:
|
||||
"""Test creating preset ManagerAi configurations."""
|
||||
|
||||
def test_balanced_preset(self, session):
|
||||
"""Test creating a balanced AI preset."""
|
||||
balanced = ManagerAi(name="Balanced")
|
||||
session.add(balanced)
|
||||
session.commit()
|
||||
|
||||
# All defaults should be 5 (balanced)
|
||||
assert all(getattr(balanced, field) == 5 for field in [
|
||||
'steal', 'running', 'hold', 'catcher_throw',
|
||||
'uncapped_home', 'uncapped_third', 'uncapped_trail',
|
||||
'bullpen_matchup', 'behind_aggression', 'ahead_aggression',
|
||||
'decide_throw'
|
||||
])
|
||||
|
||||
def test_yolo_preset(self, session):
|
||||
"""Test creating an aggressive 'YOLO' AI preset."""
|
||||
yolo = ManagerAi(
|
||||
name="Yolo",
|
||||
steal=10,
|
||||
running=10,
|
||||
hold=5,
|
||||
catcher_throw=10,
|
||||
uncapped_home=10,
|
||||
uncapped_third=10,
|
||||
uncapped_trail=10,
|
||||
bullpen_matchup=3,
|
||||
behind_aggression=10,
|
||||
ahead_aggression=10,
|
||||
decide_throw=10
|
||||
)
|
||||
session.add(yolo)
|
||||
session.commit()
|
||||
|
||||
assert yolo.steal == 10
|
||||
assert yolo.running == 10
|
||||
assert yolo.bullpen_matchup == 3 # Conservative on bullpen
|
||||
assert yolo.behind_aggression == 10
|
||||
assert yolo.ahead_aggression == 10
|
||||
|
||||
def test_safe_preset(self, session):
|
||||
"""Test creating a conservative 'Safe' AI preset."""
|
||||
safe = ManagerAi(
|
||||
name="Safe",
|
||||
steal=3,
|
||||
running=3,
|
||||
hold=8,
|
||||
catcher_throw=5,
|
||||
uncapped_home=5,
|
||||
uncapped_third=3,
|
||||
uncapped_trail=5,
|
||||
bullpen_matchup=8,
|
||||
behind_aggression=5,
|
||||
ahead_aggression=1,
|
||||
decide_throw=1
|
||||
)
|
||||
session.add(safe)
|
||||
session.commit()
|
||||
|
||||
assert safe.steal == 3
|
||||
assert safe.running == 3
|
||||
assert safe.hold == 8 # High hold tendency
|
||||
assert safe.bullpen_matchup == 8 # Conservative bullpen usage
|
||||
assert safe.ahead_aggression == 1 # Very conservative when ahead
|
||||
assert safe.decide_throw == 1
|
||||
341
tests/unit/models/test_team.py
Normal file
341
tests/unit/models/test_team.py
Normal file
@ -0,0 +1,341 @@
|
||||
"""
|
||||
Unit tests for Team model.
|
||||
|
||||
Tests data validation, field constraints, model behavior, and properties.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import datetime
|
||||
from pydantic import ValidationError
|
||||
from sqlmodel import Session, SQLModel, create_engine, text, select, func
|
||||
|
||||
from app.models.team import Team, TeamBase
|
||||
from tests.factories.team_factory import TeamFactory
|
||||
|
||||
|
||||
# Using centralized fixtures from conftest.py for proper test isolation
|
||||
|
||||
|
||||
class TestTeamBase:
|
||||
"""Test TeamBase model validation."""
|
||||
|
||||
def test_create_with_minimal_required_fields(self):
|
||||
"""Test creating Team with minimal required fields."""
|
||||
team = TeamBase(
|
||||
id=12345,
|
||||
abbrev="LAD",
|
||||
sname="Dodgers",
|
||||
lname="Los Angeles Dodgers",
|
||||
gmid=67890,
|
||||
gmname="Test GM",
|
||||
gsheet="sheet-url",
|
||||
wallet=25000,
|
||||
team_value=100000,
|
||||
collection_value=75000,
|
||||
color="005a9c",
|
||||
season=9,
|
||||
career=1,
|
||||
ranking=15,
|
||||
has_guide=False,
|
||||
is_ai=False
|
||||
)
|
||||
|
||||
assert team.id == 12345
|
||||
assert team.abbrev == "LAD"
|
||||
assert team.sname == "Dodgers"
|
||||
assert team.lname == "Los Angeles Dodgers"
|
||||
assert team.gmid == 67890
|
||||
assert team.gmname == "Test GM"
|
||||
assert team.gsheet == "sheet-url"
|
||||
assert team.wallet == 25000
|
||||
assert team.team_value == 100000
|
||||
assert team.collection_value == 75000
|
||||
assert team.logo is None # Default value
|
||||
assert team.color == "005a9c"
|
||||
assert team.season == 9
|
||||
assert team.career == 1
|
||||
assert team.ranking == 15
|
||||
assert team.has_guide is False
|
||||
assert team.is_ai is False
|
||||
assert isinstance(team.created, datetime.datetime)
|
||||
|
||||
def test_create_with_all_fields(self):
|
||||
"""Test creating Team with all fields including optional ones."""
|
||||
team = TeamBase(
|
||||
id=12345,
|
||||
abbrev="NYY",
|
||||
sname="Yankees",
|
||||
lname="New York Yankees",
|
||||
gmid=67890,
|
||||
gmname="Test GM",
|
||||
gsheet="sheet-url",
|
||||
wallet=50000,
|
||||
team_value=150000,
|
||||
collection_value=125000,
|
||||
logo="https://example.com/logo.png",
|
||||
color="c4ced4",
|
||||
season=9,
|
||||
career=3,
|
||||
ranking=1,
|
||||
has_guide=True,
|
||||
is_ai=True
|
||||
)
|
||||
|
||||
assert team.logo == "https://example.com/logo.png"
|
||||
assert team.has_guide is True
|
||||
assert team.is_ai is True
|
||||
|
||||
def test_description_property_human_team(self):
|
||||
"""Test description property for human team."""
|
||||
team = TeamBase(
|
||||
id=100,
|
||||
abbrev="SF",
|
||||
lname="San Francisco Giants",
|
||||
is_ai=False,
|
||||
# ... other required fields
|
||||
sname="Giants",
|
||||
gmid=200,
|
||||
gmname="Test GM",
|
||||
gsheet="sheet-url",
|
||||
wallet=25000,
|
||||
team_value=100000,
|
||||
collection_value=75000,
|
||||
color="fd5a1e",
|
||||
season=9,
|
||||
career=1,
|
||||
ranking=10,
|
||||
has_guide=False
|
||||
)
|
||||
|
||||
assert team.description == "100. SF San Francisco Giants, Human"
|
||||
|
||||
def test_description_property_ai_team(self):
|
||||
"""Test description property for AI team."""
|
||||
team = TeamBase(
|
||||
id=200,
|
||||
abbrev="AI",
|
||||
lname="AI Team",
|
||||
is_ai=True,
|
||||
# ... other required fields
|
||||
sname="AI",
|
||||
gmid=300,
|
||||
gmname="AI Manager",
|
||||
gsheet="sheet-url",
|
||||
wallet=25000,
|
||||
team_value=100000,
|
||||
collection_value=75000,
|
||||
color="000000",
|
||||
season=9,
|
||||
career=1,
|
||||
ranking=50,
|
||||
has_guide=False
|
||||
)
|
||||
|
||||
assert team.description == "200. AI AI Team, AI"
|
||||
|
||||
|
||||
class TestTeamModel:
|
||||
"""Test Team model database operations."""
|
||||
|
||||
def test_create_team_in_database(self, db_session):
|
||||
"""Test creating and saving Team to database."""
|
||||
team = TeamFactory.create(
|
||||
db_session,
|
||||
abbrev="BOS",
|
||||
lname="Boston Red Sox",
|
||||
wallet=40000
|
||||
)
|
||||
|
||||
# Verify team was saved
|
||||
assert team.id is not None
|
||||
assert team.abbrev == "BOS"
|
||||
assert team.lname == "Boston Red Sox"
|
||||
assert team.wallet == 40000
|
||||
|
||||
# Verify we can retrieve it
|
||||
retrieved = db_session.get(Team, team.id)
|
||||
assert retrieved is not None
|
||||
assert retrieved.abbrev == "BOS"
|
||||
|
||||
def test_team_uniqueness(self, db_session):
|
||||
"""Test that team IDs must be unique."""
|
||||
from tests.conftest import generate_unique_id
|
||||
team_id = generate_unique_id()
|
||||
|
||||
# Create first team
|
||||
team1 = TeamFactory.create(db_session, id=team_id, abbrev="T1")
|
||||
assert team1.id == team_id
|
||||
|
||||
# Attempt to create second team with same ID should fail
|
||||
with pytest.raises(Exception): # Database integrity error
|
||||
team2 = TeamFactory.build(id=team_id, abbrev="T2")
|
||||
db_session.add(team2)
|
||||
db_session.commit()
|
||||
|
||||
def test_multiple_teams_different_ids(self, db_session):
|
||||
"""Test creating multiple teams with different IDs."""
|
||||
teams = TeamFactory.build_multiple(3)
|
||||
|
||||
for team in teams:
|
||||
db_session.add(team)
|
||||
db_session.commit()
|
||||
|
||||
# Verify all teams were saved
|
||||
all_teams = db_session.exec(select(func.count(Team.id))).first()
|
||||
assert all_teams >= 3
|
||||
|
||||
def test_ai_team_factory(self):
|
||||
"""Test AI team factory creates correct AI team."""
|
||||
ai_team = TeamFactory.build_ai_team(abbrev="AI1")
|
||||
|
||||
assert ai_team.is_ai is True
|
||||
assert ai_team.abbrev == "AI1"
|
||||
assert ai_team.lname == "AI Team"
|
||||
assert ai_team.gmname == "AI Manager"
|
||||
|
||||
def test_human_team_factory(self):
|
||||
"""Test human team factory creates correct human team."""
|
||||
human_team = TeamFactory.build_human_team(abbrev="HUM1")
|
||||
|
||||
assert human_team.is_ai is False
|
||||
assert human_team.abbrev == "HUM1"
|
||||
assert human_team.lname == "Human Team"
|
||||
assert human_team.gmname == "Human Manager"
|
||||
assert human_team.wallet == 50000
|
||||
|
||||
def test_team_field_validation(self):
|
||||
"""Test field validation and constraints."""
|
||||
# Test that required fields are actually required
|
||||
with pytest.raises(ValidationError):
|
||||
TeamBase() # Missing required fields
|
||||
|
||||
# Test that fields accept expected types
|
||||
team = TeamBase(
|
||||
id=12345,
|
||||
abbrev="TEST",
|
||||
sname="Test",
|
||||
lname="Test Team",
|
||||
gmid=67890,
|
||||
gmname="Test GM",
|
||||
gsheet="sheet-url",
|
||||
wallet=25000,
|
||||
team_value=100000,
|
||||
collection_value=75000,
|
||||
color="ff0000",
|
||||
season=9,
|
||||
career=1,
|
||||
ranking=15,
|
||||
has_guide=False,
|
||||
is_ai=False
|
||||
)
|
||||
|
||||
# Verify types are preserved
|
||||
assert isinstance(team.id, int)
|
||||
assert isinstance(team.wallet, int)
|
||||
assert isinstance(team.has_guide, bool)
|
||||
assert isinstance(team.is_ai, bool)
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test that default values are applied correctly."""
|
||||
team = TeamBase(
|
||||
id=12345,
|
||||
abbrev="TEST",
|
||||
sname="Test",
|
||||
lname="Test Team",
|
||||
gmid=67890,
|
||||
gmname="Test GM",
|
||||
gsheet="sheet-url",
|
||||
wallet=25000,
|
||||
team_value=100000,
|
||||
collection_value=75000,
|
||||
color="ff0000",
|
||||
season=9,
|
||||
career=1,
|
||||
ranking=15,
|
||||
has_guide=False,
|
||||
is_ai=False
|
||||
)
|
||||
|
||||
# Test default values
|
||||
assert team.logo is None
|
||||
assert isinstance(team.created, datetime.datetime)
|
||||
|
||||
# Test that created timestamp is recent (within last minute)
|
||||
now = datetime.datetime.now()
|
||||
assert (now - team.created).total_seconds() < 60
|
||||
|
||||
def test_team_string_representations(self):
|
||||
"""Test various string field scenarios."""
|
||||
team = TeamFactory.build(
|
||||
abbrev="LONG", # Test longer abbreviations
|
||||
sname="S", # Test single character
|
||||
lname="Very Long Team Name Here", # Test long names
|
||||
gmname="Manager with Spaces",
|
||||
color="ffffff" # Test hex color
|
||||
)
|
||||
|
||||
assert len(team.abbrev) > 3
|
||||
assert len(team.sname) == 1
|
||||
assert " " in team.lname
|
||||
assert " " in team.gmname
|
||||
assert team.color == "ffffff"
|
||||
|
||||
|
||||
class TestTeamFactoryEdgeCases:
|
||||
"""Test edge cases and advanced scenarios with TeamFactory."""
|
||||
|
||||
def test_factory_override_all_defaults(self):
|
||||
"""Test that all factory defaults can be overridden."""
|
||||
custom_team = TeamFactory.build(
|
||||
id=999,
|
||||
abbrev="CUST",
|
||||
sname="Custom",
|
||||
lname="Custom Team",
|
||||
gmid=888,
|
||||
gmname="Custom GM",
|
||||
gsheet="custom-sheet",
|
||||
wallet=99999,
|
||||
team_value=200000,
|
||||
collection_value=150000,
|
||||
logo="custom-logo.png",
|
||||
color="abcdef",
|
||||
season=10,
|
||||
career=5,
|
||||
ranking=1,
|
||||
has_guide=True,
|
||||
is_ai=True
|
||||
)
|
||||
|
||||
assert custom_team.id == 999
|
||||
assert custom_team.abbrev == "CUST"
|
||||
assert custom_team.wallet == 99999
|
||||
assert custom_team.season == 10
|
||||
assert custom_team.career == 5
|
||||
assert custom_team.is_ai is True
|
||||
|
||||
def test_multiple_teams_unique_ids(self):
|
||||
"""Test that multiple teams have unique IDs."""
|
||||
teams = TeamFactory.build_multiple(5)
|
||||
|
||||
ids = [team.id for team in teams]
|
||||
assert len(set(ids)) == 5 # All IDs should be unique
|
||||
|
||||
gmids = [team.gmid for team in teams]
|
||||
assert len(set(gmids)) == 5 # All GM IDs should be unique
|
||||
|
||||
def test_factory_build_vs_create(self, db_session):
|
||||
"""Test difference between build (unsaved) and create (saved)."""
|
||||
# Build doesn't save to database
|
||||
built_team = TeamFactory.build(abbrev="BUILD")
|
||||
|
||||
# Should not exist in database yet
|
||||
retrieved = db_session.get(Team, built_team.id)
|
||||
assert retrieved is None
|
||||
|
||||
# Create saves to database
|
||||
created_team = TeamFactory.create(db_session, abbrev="CREATE")
|
||||
|
||||
# Should exist in database
|
||||
retrieved = db_session.get(Team, created_team.id)
|
||||
assert retrieved is not None
|
||||
assert retrieved.abbrev == "CREATE"
|
||||
441
tests/unit/services/test_ai_service.py
Normal file
441
tests/unit/services/test_ai_service.py
Normal file
@ -0,0 +1,441 @@
|
||||
"""
|
||||
Unit tests for AIService.
|
||||
|
||||
Tests AI decision-making business logic extracted from ManagerAi model.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, MagicMock
|
||||
from sqlmodel import Session
|
||||
|
||||
from app.services.ai_service import AIService
|
||||
from app.models.manager_ai import ManagerAi
|
||||
from app.models.ai_responses import (
|
||||
JumpResponse,
|
||||
TagResponse,
|
||||
ThrowResponse,
|
||||
UncappedRunResponse,
|
||||
DefenseResponse,
|
||||
RunResponse,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
"""Create mock database session."""
|
||||
return Mock(spec=Session)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ai_service(mock_session):
|
||||
"""Create AIService instance with mocked session."""
|
||||
return AIService(mock_session)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def balanced_ai():
|
||||
"""Create balanced ManagerAi configuration."""
|
||||
return ManagerAi(
|
||||
name="Balanced",
|
||||
steal=5,
|
||||
running=5,
|
||||
hold=5,
|
||||
catcher_throw=5,
|
||||
uncapped_home=5,
|
||||
uncapped_third=5,
|
||||
uncapped_trail=5,
|
||||
bullpen_matchup=5,
|
||||
behind_aggression=5,
|
||||
ahead_aggression=5,
|
||||
decide_throw=5
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def aggressive_ai():
|
||||
"""Create aggressive ManagerAi configuration."""
|
||||
return ManagerAi(
|
||||
name="Yolo",
|
||||
steal=10,
|
||||
running=10,
|
||||
hold=5,
|
||||
catcher_throw=10,
|
||||
uncapped_home=10,
|
||||
uncapped_third=10,
|
||||
uncapped_trail=10,
|
||||
bullpen_matchup=3,
|
||||
behind_aggression=10,
|
||||
ahead_aggression=10,
|
||||
decide_throw=10
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def conservative_ai():
|
||||
"""Create conservative ManagerAi configuration."""
|
||||
return ManagerAi(
|
||||
name="Safe",
|
||||
steal=3,
|
||||
running=3,
|
||||
hold=8,
|
||||
catcher_throw=5,
|
||||
uncapped_home=5,
|
||||
uncapped_third=3,
|
||||
uncapped_trail=5,
|
||||
bullpen_matchup=8,
|
||||
behind_aggression=5,
|
||||
ahead_aggression=1,
|
||||
decide_throw=1
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_game():
|
||||
"""Create mock game object."""
|
||||
game = Mock()
|
||||
game.id = 1
|
||||
game.ai_team = 'home'
|
||||
return game
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_play():
|
||||
"""Create mock play object."""
|
||||
play = Mock()
|
||||
play.starting_outs = 0
|
||||
play.outs = 0
|
||||
play.away_score = 3
|
||||
play.home_score = 3
|
||||
play.inning_num = 5
|
||||
play.on_base_code = 1
|
||||
play.ai_run_diff = 0
|
||||
play.could_walkoff = False
|
||||
play.is_new_inning = False
|
||||
|
||||
# Mock runners
|
||||
play.on_first = Mock()
|
||||
play.on_first.player.name = "Runner One"
|
||||
play.on_first.card.batterscouting.battingcard.steal_auto = False
|
||||
play.on_first.card.batterscouting.battingcard.steal_high = 15
|
||||
play.on_first.card.batterscouting.battingcard.steal_low = 12
|
||||
|
||||
play.on_second = Mock()
|
||||
play.on_second.player.name = "Runner Two"
|
||||
play.on_second.card.batterscouting.battingcard.steal_auto = False
|
||||
play.on_second.card.batterscouting.battingcard.steal_low = 10
|
||||
|
||||
play.on_third = Mock()
|
||||
play.on_third.player.name = "Runner Three"
|
||||
play.on_third.card.batterscouting.battingcard.steal_low = 8
|
||||
|
||||
# Mock pitcher and catcher
|
||||
play.pitcher.card.pitcherscouting.pitchingcard.hold = 3
|
||||
play.catcher.player_id = 100
|
||||
play.catcher.card.variant = 0
|
||||
|
||||
return play
|
||||
|
||||
|
||||
class TestAIServiceInitialization:
|
||||
"""Test AIService initialization and basic functionality."""
|
||||
|
||||
def test_initialization(self, mock_session):
|
||||
"""Test AIService initializes correctly."""
|
||||
service = AIService(mock_session)
|
||||
assert service.session == mock_session
|
||||
assert service.logger is not None
|
||||
|
||||
def test_inherits_from_base_service(self, ai_service):
|
||||
"""Test AIService inherits BaseService functionality."""
|
||||
assert hasattr(ai_service, '_log_operation')
|
||||
assert hasattr(ai_service, '_log_error')
|
||||
assert hasattr(ai_service, '_validate_required_fields')
|
||||
|
||||
|
||||
class TestCheckStealOpportunity:
|
||||
"""Test check_steal_opportunity method."""
|
||||
|
||||
def test_steal_to_second_aggressive(self, ai_service, aggressive_ai, mock_game, mock_play):
|
||||
"""Test steal decision to second base with aggressive AI."""
|
||||
mock_game.current_play_or_none.return_value = mock_play
|
||||
mock_catcher_defense = Mock()
|
||||
mock_catcher_defense.arm = 5
|
||||
ai_service.session.exec.return_value.one.return_value = mock_catcher_defense
|
||||
|
||||
result = ai_service.check_steal_opportunity(aggressive_ai, mock_game, 2)
|
||||
|
||||
assert isinstance(result, JumpResponse)
|
||||
assert result.min_safe == 12 # 12 + 0 outs for steal=10
|
||||
assert result.run_if_auto_jump is True # steal > 7
|
||||
|
||||
def test_steal_to_second_conservative(self, ai_service, conservative_ai, mock_game, mock_play):
|
||||
"""Test steal decision to second base with conservative AI."""
|
||||
mock_game.current_play_or_none.return_value = mock_play
|
||||
mock_catcher_defense = Mock()
|
||||
mock_catcher_defense.arm = 5
|
||||
ai_service.session.exec.return_value.one.return_value = mock_catcher_defense
|
||||
|
||||
result = ai_service.check_steal_opportunity(conservative_ai, mock_game, 2)
|
||||
|
||||
assert isinstance(result, JumpResponse)
|
||||
assert result.min_safe == 16 # 16 + 0 outs for steal=3
|
||||
assert result.must_auto_jump is True # steal < 5
|
||||
|
||||
def test_steal_to_third(self, ai_service, aggressive_ai, mock_game, mock_play):
|
||||
"""Test steal decision to third base."""
|
||||
mock_game.current_play_or_none.return_value = mock_play
|
||||
mock_catcher_defense = Mock()
|
||||
mock_catcher_defense.arm = 5
|
||||
ai_service.session.exec.return_value.one.return_value = mock_catcher_defense
|
||||
|
||||
result = ai_service.check_steal_opportunity(aggressive_ai, mock_game, 3)
|
||||
|
||||
assert isinstance(result, JumpResponse)
|
||||
assert result.min_safe == 12 # 12 + 0 outs for steal=10
|
||||
assert result.run_if_auto_jump is True
|
||||
|
||||
def test_no_current_play_raises_error(self, ai_service, balanced_ai, mock_game):
|
||||
"""Test that missing current play raises ValueError."""
|
||||
mock_game.current_play_or_none.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="No game found while checking for steal"):
|
||||
ai_service.check_steal_opportunity(balanced_ai, mock_game, 2)
|
||||
|
||||
def test_no_runner_on_first_raises_error(self, ai_service, balanced_ai, mock_game, mock_play):
|
||||
"""Test that missing runner on first raises ValueError."""
|
||||
mock_play.on_first = None
|
||||
mock_game.current_play_or_none.return_value = mock_play
|
||||
mock_catcher_defense = Mock()
|
||||
mock_catcher_defense.arm = 5
|
||||
ai_service.session.exec.return_value.one.return_value = mock_catcher_defense
|
||||
|
||||
with pytest.raises(ValueError, match="no runner found on first"):
|
||||
ai_service.check_steal_opportunity(balanced_ai, mock_game, 2)
|
||||
|
||||
|
||||
class TestTagDecisions:
|
||||
"""Test tag-up decision methods."""
|
||||
|
||||
def test_tag_from_second_aggressive(self, ai_service, aggressive_ai, mock_game, mock_play):
|
||||
"""Test tag from second with aggressive AI."""
|
||||
mock_game.current_play_or_none.return_value = mock_play
|
||||
|
||||
result = ai_service.check_tag_from_second(aggressive_ai, mock_game)
|
||||
|
||||
assert isinstance(result, TagResponse)
|
||||
# aggressive_ai.running=10 + aggression_mod=5 = 15 >= 8, so min_safe=4
|
||||
# starting_outs=0 != 1, so +2, final=6
|
||||
assert result.min_safe == 6
|
||||
|
||||
def test_tag_from_second_conservative(self, ai_service, conservative_ai, mock_game, mock_play):
|
||||
"""Test tag from second with conservative AI."""
|
||||
mock_game.current_play_or_none.return_value = mock_play
|
||||
|
||||
result = ai_service.check_tag_from_second(conservative_ai, mock_game)
|
||||
|
||||
assert isinstance(result, TagResponse)
|
||||
# conservative_ai.running=3 + aggression_mod=4 = 7 < 8, so min_safe=10
|
||||
# starting_outs=0 != 1, so +2, final=12
|
||||
assert result.min_safe == 12
|
||||
|
||||
def test_tag_from_third_one_out(self, ai_service, balanced_ai, mock_game, mock_play):
|
||||
"""Test tag from third with one out."""
|
||||
mock_play.starting_outs = 1
|
||||
mock_play.ai_run_diff = 2 # Not in [-1, 0] range to avoid extra -2
|
||||
mock_game.current_play_or_none.return_value = mock_play
|
||||
|
||||
result = ai_service.check_tag_from_third(balanced_ai, mock_game)
|
||||
|
||||
assert isinstance(result, TagResponse)
|
||||
# balanced_ai.running=5 + aggression_mod=0 = 5 < 8, so min_safe=10
|
||||
# starting_outs=1, so -2, final=8
|
||||
assert result.min_safe == 8
|
||||
|
||||
|
||||
class TestThrowDecisions:
|
||||
"""Test throw target decision methods."""
|
||||
|
||||
def test_throw_decision_big_lead(self, ai_service, aggressive_ai, mock_game, mock_play):
|
||||
"""Test throw decision when AI has big lead."""
|
||||
mock_play.ai_run_diff = 6 # Big lead
|
||||
mock_game.current_play_or_none.return_value = mock_play
|
||||
|
||||
result = ai_service.decide_throw_target(aggressive_ai, mock_game)
|
||||
|
||||
assert isinstance(result, ThrowResponse)
|
||||
assert result.at_trail_runner is True
|
||||
assert result.trail_max_safe_delta == -4 # -4 + 0 current_outs
|
||||
|
||||
def test_throw_decision_close_game(self, ai_service, balanced_ai, mock_game, mock_play):
|
||||
"""Test throw decision in close game."""
|
||||
mock_play.ai_run_diff = 0 # Tied game
|
||||
mock_game.current_play_or_none.return_value = mock_play
|
||||
|
||||
result = ai_service.decide_throw_target(balanced_ai, mock_game)
|
||||
|
||||
assert isinstance(result, ThrowResponse)
|
||||
# Default values for close game with balanced AI
|
||||
assert result.at_trail_runner is False
|
||||
assert result.cutoff is False
|
||||
|
||||
|
||||
class TestRunnerAdvanceDecisions:
|
||||
"""Test runner advance decision methods."""
|
||||
|
||||
def test_uncapped_advance_to_home(self, ai_service, aggressive_ai, mock_game, mock_play):
|
||||
"""Test uncapped advance decision for runner going home."""
|
||||
mock_play.ai_run_diff = 2
|
||||
mock_game.current_play_or_none.return_value = mock_play
|
||||
|
||||
result = ai_service.decide_runner_advance(aggressive_ai, mock_game, 4, 3)
|
||||
|
||||
assert isinstance(result, UncappedRunResponse)
|
||||
# ai_rd=2, lead_base=4: min_safe = 12 - 0 - 5 = 7
|
||||
assert result.min_safe == 7
|
||||
assert result.send_trail is True
|
||||
|
||||
def test_uncapped_advance_bounds_checking(self, ai_service, aggressive_ai, mock_game, mock_play):
|
||||
"""Test that advance decisions respect bounds."""
|
||||
mock_play.ai_run_diff = -10 # Way behind
|
||||
mock_play.starting_outs = 2
|
||||
mock_game.current_play_or_none.return_value = mock_play
|
||||
|
||||
result = ai_service.decide_runner_advance(aggressive_ai, mock_game, 4, 3)
|
||||
|
||||
assert isinstance(result, UncappedRunResponse)
|
||||
# Should be bounded between 1 and 20
|
||||
assert 1 <= result.min_safe <= 20
|
||||
assert 1 <= result.trail_min_safe <= 20
|
||||
|
||||
|
||||
class TestDefensiveAlignment:
|
||||
"""Test defensive alignment decisions."""
|
||||
|
||||
def test_defense_with_runner_on_third_walkoff(self, ai_service, balanced_ai, mock_game, mock_play):
|
||||
"""Test defensive alignment with walkoff situation."""
|
||||
mock_play.on_third = Mock()
|
||||
mock_play.on_third.player.name = "Walkoff Runner"
|
||||
mock_play.could_walkoff = True
|
||||
mock_play.starting_outs = 1
|
||||
mock_game.current_play_or_none.return_value = mock_play
|
||||
|
||||
mock_catcher_defense = Mock()
|
||||
mock_catcher_defense.arm = 5
|
||||
ai_service.session.exec.return_value.one.return_value = mock_catcher_defense
|
||||
|
||||
result = ai_service.set_defensive_alignment(balanced_ai, mock_game)
|
||||
|
||||
assert isinstance(result, DefenseResponse)
|
||||
assert result.outfield_in is True
|
||||
assert result.infield_in is True
|
||||
assert "play the outfield and infield in" in result.ai_note
|
||||
|
||||
def test_defense_two_outs_hold_runners(self, ai_service, balanced_ai, mock_game, mock_play):
|
||||
"""Test defensive holds with two outs."""
|
||||
mock_play.starting_outs = 2
|
||||
mock_play.on_base_code = 1 # Runner on first
|
||||
mock_game.current_play_or_none.return_value = mock_play
|
||||
|
||||
mock_catcher_defense = Mock()
|
||||
mock_catcher_defense.arm = 5
|
||||
ai_service.session.exec.return_value.one.return_value = mock_catcher_defense
|
||||
|
||||
result = ai_service.set_defensive_alignment(balanced_ai, mock_game)
|
||||
|
||||
assert isinstance(result, DefenseResponse)
|
||||
assert result.hold_first is True
|
||||
assert "hold Runner One on 1st" in result.ai_note
|
||||
|
||||
|
||||
class TestGroundballDecisions:
|
||||
"""Test groundball-specific decisions."""
|
||||
|
||||
def test_groundball_running_decision(self, ai_service, balanced_ai, mock_game, mock_play):
|
||||
"""Test groundball running decision."""
|
||||
mock_game.current_play_or_none.return_value = mock_play
|
||||
|
||||
result = ai_service.decide_groundball_running(balanced_ai, mock_game)
|
||||
|
||||
assert isinstance(result, RunResponse)
|
||||
# min_safe = 15 - aggression(0) = 15
|
||||
assert result.min_safe == 15
|
||||
|
||||
def test_groundball_throw_decision(self, ai_service, balanced_ai, mock_game, mock_play):
|
||||
"""Test groundball throw decision."""
|
||||
mock_game.current_play_or_none.return_value = mock_play
|
||||
|
||||
result = ai_service.decide_groundball_throw(balanced_ai, mock_game, 10, 3)
|
||||
|
||||
assert isinstance(result, ThrowResponse)
|
||||
# (10 - 4 + 3) = 9 <= (10 + 0) = 10, so at_lead_runner=True
|
||||
assert result.at_lead_runner is True
|
||||
|
||||
|
||||
class TestPitcherReplacement:
|
||||
"""Test pitcher replacement decisions."""
|
||||
|
||||
def test_should_replace_fatigued_starter(self, ai_service, balanced_ai, mock_game, mock_play):
|
||||
"""Test pitcher replacement for fatigued starter."""
|
||||
mock_play.pitcher.replacing_id = None # This is a starter
|
||||
mock_play.pitcher.is_fatigued = True
|
||||
mock_play.on_base_code = 2 # Runners on base
|
||||
mock_play.pitcher.card.pitcherscouting.pitchingcard.starter_rating = 5
|
||||
mock_game.current_play_or_none.return_value = mock_play
|
||||
|
||||
# Mock database queries
|
||||
ai_service.session.exec.return_value.one.side_effect = [18, 6] # 18 outs, 6 allowed runners
|
||||
|
||||
result = ai_service.should_replace_pitcher(balanced_ai, mock_game)
|
||||
|
||||
assert result is True # Fatigued starter with runners should be replaced
|
||||
|
||||
def test_should_keep_effective_starter(self, ai_service, balanced_ai, mock_game, mock_play):
|
||||
"""Test keeping effective starter."""
|
||||
mock_play.pitcher.replacing_id = None # This is a starter
|
||||
mock_play.pitcher.is_fatigued = False
|
||||
mock_play.on_base_code = 0 # No runners
|
||||
mock_play.pitcher.card.pitcherscouting.pitchingcard.starter_rating = 6
|
||||
mock_game.current_play_or_none.return_value = mock_play
|
||||
|
||||
# Mock database queries - effective pitcher
|
||||
ai_service.session.exec.return_value.one.side_effect = [15, 2] # 15 outs, 2 allowed runners
|
||||
|
||||
result = ai_service.should_replace_pitcher(balanced_ai, mock_game)
|
||||
|
||||
assert result is False # Effective starter should stay in
|
||||
|
||||
def test_should_replace_overworked_reliever(self, ai_service, balanced_ai, mock_game, mock_play):
|
||||
"""Test replacing overworked reliever."""
|
||||
mock_play.pitcher.replacing_id = 123 # This is a reliever
|
||||
mock_play.pitcher.card.pitcherscouting.pitchingcard.relief_rating = 3
|
||||
mock_game.current_play_or_none.return_value = mock_play
|
||||
|
||||
# Mock database queries - overworked reliever
|
||||
ai_service.session.exec.return_value.one.side_effect = [12, 4] # 12 outs (4 IP), 4 allowed runners
|
||||
|
||||
result = ai_service.should_replace_pitcher(balanced_ai, mock_game)
|
||||
|
||||
assert result is True # Overworked reliever should be replaced
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
"""Test error handling in AIService methods."""
|
||||
|
||||
def test_methods_handle_no_current_play(self, ai_service, balanced_ai, mock_game):
|
||||
"""Test that all methods handle missing current play gracefully."""
|
||||
mock_game.current_play_or_none.return_value = None
|
||||
|
||||
methods_to_test = [
|
||||
(ai_service.check_tag_from_second, (balanced_ai, mock_game)),
|
||||
(ai_service.check_tag_from_third, (balanced_ai, mock_game)),
|
||||
(ai_service.decide_throw_target, (balanced_ai, mock_game)),
|
||||
(ai_service.decide_runner_advance, (balanced_ai, mock_game, 4, 3)),
|
||||
(ai_service.set_defensive_alignment, (balanced_ai, mock_game)),
|
||||
(ai_service.decide_groundball_running, (balanced_ai, mock_game)),
|
||||
(ai_service.decide_groundball_throw, (balanced_ai, mock_game, 10, 3)),
|
||||
(ai_service.should_replace_pitcher, (balanced_ai, mock_game)),
|
||||
]
|
||||
|
||||
for method, args in methods_to_test:
|
||||
with pytest.raises(ValueError, match="No game found"):
|
||||
method(*args)
|
||||
170
tests/unit/services/test_ui_service.py
Normal file
170
tests/unit/services/test_ui_service.py
Normal file
@ -0,0 +1,170 @@
|
||||
"""
|
||||
Unit tests for UIService.
|
||||
|
||||
Tests business logic extracted from models, particularly team display formatting.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
from sqlmodel import Session
|
||||
|
||||
from app.services.ui_service import UIService
|
||||
from tests.factories.team_factory import TeamFactory
|
||||
|
||||
|
||||
class TestUIService:
|
||||
"""Test UIService functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self):
|
||||
"""Create mock database session."""
|
||||
return Mock(spec=Session)
|
||||
|
||||
@pytest.fixture
|
||||
def ui_service(self, mock_session):
|
||||
"""Create UIService instance with mock session."""
|
||||
return UIService(mock_session)
|
||||
|
||||
def test_format_team_display_with_custom_logo_and_color(self, ui_service):
|
||||
"""Test team display formatting with custom logo and color."""
|
||||
team = TeamFactory.build(
|
||||
id=123,
|
||||
abbrev="LAD",
|
||||
lname="Los Angeles Dodgers",
|
||||
logo="https://example.com/dodgers-logo.png",
|
||||
color="005a9c",
|
||||
season=9
|
||||
)
|
||||
|
||||
result = ui_service.format_team_display(team)
|
||||
|
||||
assert result['title'] == "Los Angeles Dodgers"
|
||||
assert result['color'] == "005a9c"
|
||||
assert result['footer_text'] == "Paper Dynasty Season 9"
|
||||
assert result['footer_icon'] == "https://paper-dynasty.s3.us-east-1.amazonaws.com/static-images/sba-logo.png"
|
||||
assert result['thumbnail'] == "https://example.com/dodgers-logo.png"
|
||||
assert result['team_id'] == 123
|
||||
assert result['abbrev'] == "LAD"
|
||||
assert result['season'] == 9
|
||||
|
||||
def test_format_team_display_with_defaults(self, ui_service):
|
||||
"""Test team display formatting with default logo and color."""
|
||||
team = TeamFactory.build(
|
||||
id=456,
|
||||
abbrev="SF",
|
||||
lname="San Francisco Giants",
|
||||
logo=None, # Should use default
|
||||
color=None, # Should use default
|
||||
season=8
|
||||
)
|
||||
|
||||
result = ui_service.format_team_display(team)
|
||||
|
||||
# Should use default values
|
||||
assert result['color'] == "a6ce39" # SBA_COLOR default
|
||||
assert result['thumbnail'] == "https://paper-dynasty.s3.us-east-1.amazonaws.com/static-images/sba-logo.png" # SBA_LOGO default
|
||||
assert result['title'] == "San Francisco Giants"
|
||||
assert result['team_id'] == 456
|
||||
|
||||
def test_format_team_display_with_empty_color(self, ui_service):
|
||||
"""Test team display formatting with empty string color."""
|
||||
team = TeamFactory.build(
|
||||
abbrev="NYY",
|
||||
lname="New York Yankees",
|
||||
color="", # Empty string should trigger default
|
||||
season=9
|
||||
)
|
||||
|
||||
result = ui_service.format_team_display(team)
|
||||
|
||||
# Empty string should trigger default color
|
||||
assert result['color'] == "a6ce39"
|
||||
|
||||
def test_format_team_display_ai_team(self, ui_service):
|
||||
"""Test team display formatting for AI team."""
|
||||
ai_team = TeamFactory.build_ai_team(
|
||||
id=789,
|
||||
abbrev="AI1",
|
||||
season=10
|
||||
)
|
||||
|
||||
result = ui_service.format_team_display(ai_team)
|
||||
|
||||
assert result['title'] == "AI Team"
|
||||
assert result['team_id'] == 789
|
||||
assert result['abbrev'] == "AI1"
|
||||
assert result['season'] == 10
|
||||
|
||||
def test_format_team_display_different_seasons(self, ui_service):
|
||||
"""Test team display formatting across different seasons."""
|
||||
team_s8 = TeamFactory.build(lname="Season 8 Team", season=8)
|
||||
team_s9 = TeamFactory.build(lname="Season 9 Team", season=9)
|
||||
team_s10 = TeamFactory.build(lname="Season 10 Team", season=10)
|
||||
|
||||
result_s8 = ui_service.format_team_display(team_s8)
|
||||
result_s9 = ui_service.format_team_display(team_s9)
|
||||
result_s10 = ui_service.format_team_display(team_s10)
|
||||
|
||||
assert result_s8['footer_text'] == "Paper Dynasty Season 8"
|
||||
assert result_s9['footer_text'] == "Paper Dynasty Season 9"
|
||||
assert result_s10['footer_text'] == "Paper Dynasty Season 10"
|
||||
|
||||
def test_format_team_display_preserves_all_fields(self, ui_service):
|
||||
"""Test that all expected fields are present in formatted output."""
|
||||
team = TeamFactory.build()
|
||||
|
||||
result = ui_service.format_team_display(team)
|
||||
|
||||
expected_fields = [
|
||||
'title', 'color', 'footer_text', 'footer_icon',
|
||||
'thumbnail', 'team_id', 'abbrev', 'season'
|
||||
]
|
||||
|
||||
for field in expected_fields:
|
||||
assert field in result, f"Missing field: {field}"
|
||||
|
||||
def test_format_team_display_error_handling(self, ui_service):
|
||||
"""Test error handling in team display formatting."""
|
||||
# Test with None object should cause an AttributeError
|
||||
with pytest.raises(AttributeError):
|
||||
ui_service.format_team_display(None)
|
||||
|
||||
def test_format_team_display_logging(self, ui_service):
|
||||
"""Test that proper logging occurs during team formatting."""
|
||||
team = TeamFactory.build(abbrev="LOG")
|
||||
|
||||
# This should not raise an exception and should complete successfully
|
||||
result = ui_service.format_team_display(team)
|
||||
|
||||
assert result is not None
|
||||
assert 'title' in result
|
||||
|
||||
def test_format_team_display_hex_color_variations(self, ui_service):
|
||||
"""Test various hex color formats."""
|
||||
test_colors = [
|
||||
"ff0000", # 6-digit hex
|
||||
"000000", # Black
|
||||
"ffffff", # White
|
||||
"a6ce39", # Default SBA color
|
||||
"005a9c", # Dodgers blue
|
||||
]
|
||||
|
||||
for color in test_colors:
|
||||
team = TeamFactory.build(color=color)
|
||||
result = ui_service.format_team_display(team)
|
||||
assert result['color'] == color
|
||||
|
||||
def test_format_team_display_special_characters_in_names(self, ui_service):
|
||||
"""Test team names with special characters."""
|
||||
special_names = [
|
||||
"Team with Spaces",
|
||||
"Team-with-Hyphens",
|
||||
"Team's with Apostrophes",
|
||||
"Team & Ampersands",
|
||||
"Team (with Parentheses)",
|
||||
]
|
||||
|
||||
for name in special_names:
|
||||
team = TeamFactory.build(lname=name)
|
||||
result = ui_service.format_team_display(team)
|
||||
assert result['title'] == name
|
||||
Loading…
Reference in New Issue
Block a user