major-domo-database/app/services/mocks.py
root bcec206bb4 fix: Complete dependency injection for PlayerService
- Moved peewee/fastapi imports inside methods to enable testing without DB
- Added InMemoryQueryResult for mock-compatible filtering/sorting
- Updated interfaces with @runtime_checkable for isinstance() checks
- Fixed get_or_none() to accept keyword arguments
- _player_to_dict() now handles both dicts and Peewee models

Result: All 14 tests pass without database connection.
Service can now be fully tested with MockPlayerRepository.
2026-02-03 16:49:50 +00:00

330 lines
10 KiB
Python

"""
Enhanced Mock Implementations for Testing
Provides comprehensive in-memory mocks for full test coverage.
"""
from typing import List, Dict, Any, Optional, Callable
from collections import defaultdict
import time
import fnmatch
class MockQueryResult:
    """Enhanced mock query result that supports chaining and complex queries.

    Wraps a list of row dicts and offers a small query-like API: ``where()``
    and ``order_by()`` each return a *new* ``MockQueryResult`` (the receiver
    is not modified), and the result supports ``count()``, ``len()``,
    iteration and indexing.

    ``where()`` accepts two kinds of condition:

    * a callable ``f(row)``; rows for which it returns a falsy value are dropped;
    * a ``(field, op, value)`` tuple where ``op`` is one of ``'<<'`` (IN),
      ``'=='``, ``'!='``, ``'<'``, ``'<='``, ``'>'`` or ``'>='``.

    ``order_by()`` accepts row keys; prefix a key with ``'-'`` to sort that
    key in descending order.
    """

    # Operators understood in ``(field, op, value)`` where-conditions.  The
    # ordering comparisons treat a missing/None field as non-matching instead
    # of raising TypeError.
    _OPERATORS: Dict[str, Callable[[Any, Any], bool]] = {
        '<<': lambda actual, expected: actual in expected,  # IN
        '==': lambda actual, expected: actual == expected,
        '!=': lambda actual, expected: actual != expected,
        '<': lambda actual, expected: actual is not None and actual < expected,
        '<=': lambda actual, expected: actual is not None and actual <= expected,
        '>': lambda actual, expected: actual is not None and actual > expected,
        '>=': lambda actual, expected: actual is not None and actual >= expected,
    }

    def __init__(self, items: List[Dict[str, Any]]):
        self._items = list(items)
        # Rows this result was built from; where() hands them on to the
        # results it creates.
        self._original_items = list(items)
        # NOTE(review): the attributes below are initialised (and _filters is
        # copied by where()) but never read inside this class; kept in case
        # outside code inspects them.
        self._filters: List[Callable] = []
        self._order_by_field = None
        self._order_by_desc = False

    @classmethod
    def _compile_condition(cls, condition) -> Optional[Callable[[Dict[str, Any]], Any]]:
        """Turn one where() condition into a row predicate (None means "ignore")."""
        if callable(condition):
            return condition
        if isinstance(condition, tuple):
            field, op, value = condition
            compare = cls._OPERATORS.get(op)
            if compare is None:
                raise ValueError(f"Unsupported where() operator: {op!r}")
            return lambda item: compare(item.get(field), value)
        # Any other object (e.g. a real ORM expression) is ignored and
        # therefore matches every row.
        return None

    def where(self, *conditions) -> 'MockQueryResult':
        """Return a new result holding only the rows that satisfy every condition.

        Args:
            *conditions: callables taking a row dict, or ``(field, op, value)``
                tuples (see the class docstring for supported operators).
                Objects of any other type are ignored.

        Raises:
            ValueError: if a tuple condition uses an unsupported operator or
                does not have exactly three elements.
        """
        # Compile eagerly so a bad operator is reported even on empty results.
        predicates = [
            predicate
            for predicate in map(self._compile_condition, conditions)
            if predicate is not None
        ]
        result = MockQueryResult(self._original_items.copy())
        result._filters = self._filters.copy()
        result._items = [
            item for item in self._items
            if all(predicate(item) for predicate in predicates)
        ]
        return result

    @staticmethod
    def _resolve_order_field(field) -> tuple:
        """Return ``(row_key, descending)`` for one order_by() argument."""
        if isinstance(field, str):
            if field.startswith('-') and len(field) > 1:
                return field[1:], True
            return field, False
        if hasattr(field, '__neg__'):
            # NOTE(review): kept for backward compatibility - non-string
            # objects supporting unary minus are negated, looked up via str()
            # and sorted descending.  Intent unclear; confirm with callers.
            return str(-field), True
        return str(field), False

    @staticmethod
    def _sort_value(item: Dict[str, Any], name: str) -> tuple:
        """Build a sort key for ``item[name]`` that tolerates None values."""
        value = item.get(name, 0)
        # (is_none, value): None sorts after real values when ascending.  Two
        # None keys compare equal without ever evaluating ``None < None``.
        return (value is None, value)

    def order_by(self, *fields) -> 'MockQueryResult':
        """Return a new result with rows sorted by *fields* (most significant first).

        Each field is a row key; a leading ``'-'`` (e.g. ``'-wara'``) sorts that
        key in descending order.  Rows missing a key sort as if its value were
        0.  ``None`` values sort after all others ascending (before them
        descending).  Values of mutually non-comparable types still raise
        ``TypeError``.
        """
        result = MockQueryResult(self._items.copy())
        # list.sort is stable, so sorting by the least significant field first
        # and the most significant last gives a correct multi-key ordering
        # while allowing each key its own direction.
        for field in reversed(fields):
            name, descending = self._resolve_order_field(field)
            result._items.sort(
                key=lambda item, _name=name: self._sort_value(item, _name),
                reverse=descending,
            )
        return result

    def count(self) -> int:
        """Return the number of rows in this result."""
        return len(self._items)

    def __iter__(self):
        return iter(self._items)

    def __len__(self):
        return len(self._items)

    def __getitem__(self, index):
        return self._items[index]
class EnhancedMockRepository:
    """Enhanced mock repository with full CRUD support.

    Keeps row dicts in memory keyed by their ``id``.  Rows are stored by
    reference, so mutating a returned dict mutates the stored row.  Conditions
    accepted by ``get_or_none``/``update`` are callables taking a row dict;
    non-callable conditions are ignored.
    """

    def __init__(self, name: str = "entity"):
        self._data: Dict[int, Dict] = {}  # id -> row dict
        self._id_counter = 1  # next candidate for an auto-assigned id
        self._name = name
        self._last_query = None  # NOTE(review): not read within this class

    def _make_id(self, item: Dict) -> int:
        """Return *item*'s id, assigning the next free auto id if it has none.

        Mutates *item* in place when an id is assigned.
        """
        if item.get('id') is None:
            # Skip ids already taken (e.g. by rows added with an explicit id)
            # so an auto-assigned id never silently overwrites an existing row.
            while self._id_counter in self._data:
                self._id_counter += 1
            item['id'] = self._id_counter
            self._id_counter += 1
        return item['id']

    def select_season(self, season: int) -> 'MockQueryResult':
        """Return a MockQueryResult over rows whose ``season`` equals *season*."""
        items = [v for v in self._data.values() if v.get('season') == season]
        return MockQueryResult(items)

    def get_by_id(self, entity_id: int) -> Optional[Dict]:
        """Return the row stored under *entity_id*, or None."""
        return self._data.get(entity_id)

    def get_or_none(self, *conditions, **field_conditions) -> Optional[Dict]:
        """Return the first stored row matching every condition, or None.

        Args:
            *conditions: callables taking a row dict.
            **field_conditions: ``field=value`` equality tests, e.g. ``name='x'``.
        """
        converted_conditions = list(conditions)
        for field, value in field_conditions.items():
            # Bind field/value as defaults so each lambda keeps its own pair.
            converted_conditions.append(lambda item, f=field, v=value: item.get(f) == v)
        return next(
            (item for item in self._data.values()
             if self._matches(item, converted_conditions)),
            None,
        )

    def _matches(self, item: Dict, conditions) -> bool:
        """Return True if every callable condition accepts *item*."""
        return all(condition(item) for condition in conditions if callable(condition))

    def update(self, data: Dict, *conditions) -> int:
        """Merge *data* into every row matching *conditions*; return the count.

        With no conditions, every stored row is updated.
        """
        updated = 0
        for item in self._data.values():
            if self._matches(item, conditions):
                item.update(data)
                updated += 1
        return updated

    def insert_many(self, data: List[Dict]) -> int:
        """Add each row in *data* (via ``add``) and return how many were added."""
        count = 0
        for item in data:
            self.add(item)
            count += 1
        return count

    def delete_by_id(self, entity_id: int) -> int:
        """Delete the row with *entity_id*; return 1 if it existed, else 0."""
        if entity_id in self._data:
            del self._data[entity_id]
            return 1
        return 0

    def add(self, item: Dict) -> Dict:
        """Store *item* (assigning an id if needed) and return the same dict.

        A row already stored under the same explicit id is replaced.
        """
        self._make_id(item)
        self._data[item['id']] = item
        return item

    def clear(self):
        """Remove all rows and restart auto ids at 1."""
        self._data.clear()
        self._id_counter = 1

    def all(self) -> List[Dict]:
        """Return all stored rows in insertion order."""
        return list(self._data.values())

    def count(self) -> int:
        """Return the number of stored rows."""
        return len(self._data)
class MockPlayerRepository(EnhancedMockRepository):
    """In-memory stand-in for the player table, named ``"player"``."""

    def __init__(self):
        super().__init__("player")

    def add_player(self, player: Dict) -> Dict:
        """Store *player* (assigning an ``id`` if it lacks one) and return it.

        No validation is performed; this simply forwards to ``add``.
        """
        return self.add(player)

    def select_season(self, season: int) -> MockQueryResult:
        """Return a MockQueryResult of the players whose ``season`` equals *season*."""
        # The base class filters on the same 'season' key.
        return super().select_season(season)
class MockTeamRepository(EnhancedMockRepository):
    """In-memory stand-in for the team table, named ``"team"``."""

    def __init__(self):
        super().__init__("team")

    def add_team(self, team: Dict) -> Dict:
        """Store *team* (assigning an ``id`` if it lacks one) and return it.

        No validation is performed; this simply forwards to ``add``.
        """
        return self.add(team)

    def select_season(self, season: int) -> MockQueryResult:
        """Return a MockQueryResult of the teams whose ``season`` equals *season*."""
        # The base class filters on the same 'season' key.
        return super().select_season(season)
class EnhancedMockCache:
    """Enhanced mock cache with call tracking and TTL support.

    Values live in ``_cache`` and absolute expiry timestamps (from
    ``time.time()``) in ``_expiry``.  Expired entries are purged lazily when
    they are next looked at.  Calls to ``get``, ``set`` (including via
    ``setex``), ``keys`` and ``delete`` are recorded in ``_calls`` so tests can
    assert on cache interactions; ``exists`` is not recorded.
    """

    def __init__(self):
        self._cache: Dict[str, str] = {}  # key -> stored value
        self._expiry: Dict[str, float] = {}  # key -> absolute expiry time
        self._calls: List[Dict] = []  # tracked calls, oldest first
        self._hit_count = 0
        self._miss_count = 0

    def _is_expired(self, key: str) -> bool:
        """Return True (and purge the entry) if *key* has passed its expiry.

        Keys with no recorded expiry are never considered expired.
        """
        if key not in self._expiry:
            return False
        if time.time() < self._expiry[key]:
            return False
        # Clean up expired key; pop() tolerates an already-missing value.
        self._cache.pop(key, None)
        del self._expiry[key]
        return True

    def get(self, key: str) -> Optional[str]:
        """Return the cached value for *key*, or None on miss/expiry."""
        self._calls.append({'method': 'get', 'key': key})
        if self._is_expired(key):
            self._miss_count += 1
            return None
        if key in self._cache:
            self._hit_count += 1
            return self._cache[key]
        self._miss_count += 1
        return None

    def set(self, key: str, value: str, ttl: int = 300) -> bool:
        """Store *value* under *key* for *ttl* seconds; always returns True.

        The call log keeps at most 200 characters of the value.
        """
        self._calls.append({
            'method': 'set',
            'key': key,
            'value': value[:200] if isinstance(value, str) else str(value)[:200],
            'ttl': ttl
        })
        self._cache[key] = value
        self._expiry[key] = time.time() + ttl
        return True

    def setex(self, key: str, ttl: int, value: str) -> bool:
        """Set with explicit expiry (alias of ``set`` with reordered arguments)."""
        return self.set(key, value, ttl)

    def keys(self, pattern: str) -> List[str]:
        """Return live (non-expired) keys matching the glob *pattern*.

        Matching is case-sensitive on every platform.
        """
        self._calls.append({'method': 'keys', 'pattern': pattern})
        # Iterate over a snapshot because _is_expired() deletes entries.
        # fnmatchcase avoids the os.path.normcase() case-folding that
        # fnmatch.fnmatch applies on Windows.
        return [
            k for k in list(self._cache)
            if fnmatch.fnmatchcase(k, pattern) and not self._is_expired(k)
        ]

    def delete(self, *keys: str) -> int:
        """Delete the given keys; return how many live entries were removed.

        Expired entries are purged but not counted as deleted.
        """
        self._calls.append({'method': 'delete', 'keys': list(keys)})
        deleted = 0
        for key in keys:
            if self._is_expired(key) or key not in self._cache:
                continue
            del self._cache[key]
            self._expiry.pop(key, None)
            deleted += 1
        return deleted

    def invalidate_pattern(self, pattern: str) -> int:
        """Delete all live keys matching *pattern*; return how many were removed."""
        keys = self.keys(pattern)
        return self.delete(*keys)

    def exists(self, key: str) -> bool:
        """Return True if *key* is cached and not expired (not call-tracked)."""
        if self._is_expired(key):
            return False
        return key in self._cache

    def clear(self):
        """Drop all cached data, the call log and the hit/miss counters."""
        self._cache.clear()
        self._expiry.clear()
        self._calls.clear()
        self._hit_count = 0
        self._miss_count = 0

    def get_calls(self, method: Optional[str] = None) -> List[Dict]:
        """Return a copy of the call log, optionally only calls to *method*."""
        if method:
            return [c for c in self._calls if c.get('method') == method]
        return list(self._calls)

    def clear_calls(self):
        """Clear call history (cached data and counters are kept)."""
        self._calls.clear()

    @property
    def hit_rate(self) -> float:
        """Fraction of ``get`` calls that were hits (0.0 before any ``get``)."""
        total = self._hit_count + self._miss_count
        if total == 0:
            return 0.0
        return self._hit_count / total

    def assert_called_with(self, method: str, **kwargs) -> bool:
        """Return True if some recorded *method* call has all of *kwargs*.

        Raises:
            AssertionError: if no recorded call matches.
        """
        for call in self._calls:
            if call.get('method') == method:
                for key, value in kwargs.items():
                    if call.get(key) != value:
                        break
                else:
                    # Inner loop finished without a mismatch.
                    return True
        available = [c.get('method') for c in self._calls]
        raise AssertionError(f"Expected {method}({kwargs}) not found. Available: {available}")

    def was_called(self, method: str) -> bool:
        """Return True if *method* appears anywhere in the call log."""
        return any(c.get('method') == method for c in self._calls)
class MockCacheService:
    """Backward-compatible name that builds an EnhancedMockCache.

    ``MockCacheService()`` returns a brand-new ``EnhancedMockCache``.  Because
    ``__new__`` returns an object that is not a ``MockCacheService``,
    ``isinstance(obj, MockCacheService)`` is False for the result.
    """

    def __new__(cls):
        fresh_cache = EnhancedMockCache()
        return fresh_cache