- Moved peewee/fastapi imports inside methods to enable testing without a DB
- Added InMemoryQueryResult for mock-compatible filtering/sorting
- Updated interfaces with @runtime_checkable for isinstance() checks
- Fixed get_or_none() to accept keyword arguments
- _player_to_dict() now handles both dicts and Peewee models

Result: All 14 tests pass without a database connection. The service can now be fully tested with MockPlayerRepository.
330 lines
10 KiB
Python
330 lines
10 KiB
Python
"""
|
|
Enhanced Mock Implementations for Testing
|
|
Provides comprehensive in-memory mocks for full test coverage.
|
|
"""
|
|
|
|
from typing import List, Dict, Any, Optional, Callable
|
|
from collections import defaultdict
|
|
import time
|
|
import fnmatch
|
|
|
|
|
|
class MockQueryResult:
    """In-memory query result over a list of dict rows.

    ``where`` and ``order_by`` each return a new ``MockQueryResult`` and leave
    the receiver untouched, so calls can be chained. The result is iterable,
    sized and indexable.
    """

    class _Descending:
        """Sort-key wrapper that inverts the ordering of the wrapped value.

        Lets ``order_by`` sort any comparable value (not only numbers) in
        descending order while still using a single tuple sort key.
        """

        __slots__ = ('value',)

        def __init__(self, value):
            self.value = value

        def __eq__(self, other):
            if not isinstance(other, type(self)):
                return NotImplemented
            return self.value == other.value

        def __lt__(self, other):
            if not isinstance(other, type(self)):
                return NotImplemented
            # Reversed on purpose: "smaller" means the wrapped value is larger.
            return other.value < self.value

    # Operators accepted in ``(field, op, value)`` tuple conditions.
    # '<<' means IN (membership of the row value in ``value``).
    _OPERATORS: Dict[str, Callable[[Any, Any], bool]] = {
        '<<': lambda left, right: left in right,
        '==': lambda left, right: left == right,
        '!=': lambda left, right: left != right,
        '<': lambda left, right: left < right,
        '<=': lambda left, right: left <= right,
        '>': lambda left, right: left > right,
        '>=': lambda left, right: left >= right,
    }
    _ORDERING_OPERATORS = frozenset({'<', '<=', '>', '>='})

    def __init__(self, items: List[Dict[str, Any]]):
        """Wrap ``items``; the list itself is copied, the row dicts are not."""
        self._items = list(items)
        self._original_items = list(items)
        self._filters: List[Callable] = []
        self._order_by_field = None
        self._order_by_desc = False

    @classmethod
    def _condition_holds(cls, item, condition) -> bool:
        """Evaluate one ``where`` condition against one row.

        A callable is called with the row and its truthiness is used. A tuple
        is unpacked as ``(field, op, value)`` and compared against
        ``item.get(field)``. Anything else, and any tuple whose operator is
        not in ``_OPERATORS``, is ignored (treated as matching) to stay
        compatible with the previous behaviour.
        """
        if callable(condition):
            return bool(condition(item))
        if not isinstance(condition, tuple):
            return True
        field, op, value = condition
        item_val = item.get(field)
        compare = cls._OPERATORS.get(op)
        if compare is None:
            return True
        if op in cls._ORDERING_OPERATORS and item_val is None:
            # A missing/None value never satisfies an ordering comparison
            # (and would raise TypeError if compared).
            return False
        return compare(item_val, value)

    def where(self, *conditions) -> 'MockQueryResult':
        """Return a new result holding the current rows that match every condition.

        Args:
            *conditions: callables taking a row dict, and/or
                ``(field, op, value)`` tuples where ``op`` is one of
                ``'<<'`` (IN), ``'=='``, ``'!='``, ``'<'``, ``'<='``,
                ``'>'``, ``'>='``.
        """
        result = MockQueryResult(list(self._original_items))
        result._filters = list(self._filters)
        result._items = [
            item
            for item in self._items
            if all(self._condition_holds(item, cond) for cond in conditions)
        ]
        return result

    @staticmethod
    def _parse_order_field(field):
        """Translate one ``order_by`` argument into ``(key_name, descending)``.

        A string is a row key; a leading ``'-'`` requests descending order.
        A non-string object that defines ``__neg__`` is negated and treated
        as descending (behaviour kept from the earlier implementation);
        any other object is stringified and sorted ascending.
        """
        if isinstance(field, str):
            if field.startswith('-'):
                return field[1:], True
            return field, False
        if hasattr(field, '__neg__'):
            return str(-field), True
        return str(field), False

    def order_by(self, *fields) -> 'MockQueryResult':
        """Return a new result sorted by ``fields``, most significant first.

        Rows lacking a key sort as if the value were ``0``. The sort is
        stable, so rows that compare equal keep their current order.

        Args:
            *fields: row keys; prefix a key with ``'-'`` for descending order.
        """
        result = MockQueryResult(list(self._items))
        # Resolve names/directions once instead of once per row.
        specs = [self._parse_order_field(field) for field in fields]
        descending_key = self._Descending

        def sort_key(item):
            parts = []
            for name, descending in specs:
                val = item.get(name, 0)
                parts.append(descending_key(val) if descending else val)
            return tuple(parts)

        result._items.sort(key=sort_key)
        return result

    def count(self) -> int:
        """Return the number of rows currently in the result."""
        return len(self._items)

    def __iter__(self):
        return iter(self._items)

    def __len__(self):
        return len(self._items)

    def __getitem__(self, index):
        return self._items[index]
|
|
|
|
|
|
class EnhancedMockRepository:
    """Dict-backed repository mock with basic CRUD operations.

    Rows are plain dicts stored by their ``'id'`` value. Rows are stored by
    reference, so callers that keep a row see later ``update`` changes.
    """

    # Operators accepted in ``(field, op, value)`` tuple conditions.
    # '<<' means IN (membership of the row value in ``value``).
    _OPERATORS = {
        '<<': lambda left, right: left in right,
        '==': lambda left, right: left == right,
        '!=': lambda left, right: left != right,
        '<': lambda left, right: left < right,
        '<=': lambda left, right: left <= right,
        '>': lambda left, right: left > right,
        '>=': lambda left, right: left >= right,
    }
    _ORDERING_OPERATORS = frozenset({'<', '<=', '>', '>='})

    def __init__(self, name: str = "entity"):
        """Create an empty repository labelled ``name``."""
        self._data: Dict[int, Dict] = {}
        self._id_counter = 1
        self._name = name
        self._last_query = None

    def _make_id(self, item: Dict) -> int:
        """Return ``item['id']``, assigning a fresh one when absent or None.

        Generated ids skip values already present in the store, so a row
        added earlier with an explicit id is never silently overwritten by
        an auto-numbered one.
        """
        if item.get('id') is None:
            while self._id_counter in self._data:
                self._id_counter += 1
            item['id'] = self._id_counter
            self._id_counter += 1
        return item['id']

    def select_season(self, season: int) -> 'MockQueryResult':
        """Return a MockQueryResult of the rows whose ``'season'`` equals ``season``."""
        items = [row for row in self._data.values() if row.get('season') == season]
        return MockQueryResult(items)

    def get_by_id(self, entity_id: int) -> Optional[Dict]:
        """Return the row stored under ``entity_id``, or None."""
        return self._data.get(entity_id)

    def get_or_none(self, *conditions, **field_conditions) -> Optional[Dict]:
        """Return the first stored row matching every condition, or None.

        Args:
            *conditions: callables taking a row, and/or
                ``(field, op, value)`` tuples.
            **field_conditions: ``field=value`` equality shortcuts.
        """
        combined = list(conditions)
        for field, value in field_conditions.items():
            # Bind field/value as defaults so each lambda keeps its own pair.
            combined.append(lambda item, f=field, v=value: item.get(f) == v)

        for item in self._data.values():
            if self._matches(item, combined):
                return item
        return None

    def _tuple_condition_holds(self, item: Dict, condition: tuple) -> bool:
        """Evaluate one ``(field, op, value)`` condition against ``item``.

        Operators outside ``_OPERATORS`` are ignored (treated as matching).
        A missing/None row value never satisfies an ordering comparison.
        """
        field, op, value = condition
        item_val = item.get(field)
        compare = self._OPERATORS.get(op)
        if compare is None:
            return True
        if op in self._ORDERING_OPERATORS and item_val is None:
            return False
        return compare(item_val, value)

    def _matches(self, item: Dict, conditions) -> bool:
        """Return True when ``item`` satisfies every condition.

        Callables are called with the row; tuples are evaluated as
        ``(field, op, value)``. Conditions of any other type are ignored.
        """
        for condition in conditions:
            if callable(condition):
                if not condition(item):
                    return False
            elif isinstance(condition, tuple):
                if not self._tuple_condition_holds(item, condition):
                    return False
        return True

    def update(self, data: Dict, *conditions) -> int:
        """Copy ``data`` into every matching row; return how many rows matched.

        With no conditions every stored row is updated.
        """
        updated = 0
        for item in self._data.values():
            if self._matches(item, conditions):
                item.update(data)
                updated += 1
        return updated

    def insert_many(self, data: List[Dict]) -> int:
        """Add each row of ``data`` via ``add``; return the number added."""
        inserted = 0
        for item in data:
            self.add(item)
            inserted += 1
        return inserted

    def delete_by_id(self, entity_id: int) -> int:
        """Remove the row stored under ``entity_id``; return 1 if removed, else 0."""
        if entity_id in self._data:
            del self._data[entity_id]
            return 1
        return 0

    def add(self, item: Dict) -> Dict:
        """Store ``item`` under its id (assigning one if needed) and return it.

        The dict is stored by reference and gains an ``'id'`` key when it had
        none. Adding a row whose explicit id already exists replaces that row.
        """
        self._make_id(item)
        self._data[item['id']] = item
        return item

    def clear(self):
        """Remove every row and restart id numbering at 1."""
        self._data.clear()
        self._id_counter = 1

    def all(self) -> List[Dict]:
        """Return a new list of all stored rows."""
        return list(self._data.values())

    def count(self) -> int:
        """Return the number of stored rows."""
        return len(self._data)
|
|
|
|
|
|
class MockPlayerRepository(EnhancedMockRepository):
    """Player-flavoured in-memory repository built on EnhancedMockRepository."""

    def __init__(self):
        # The base class records this label as the repository name.
        super().__init__("player")

    def add_player(self, player: Dict) -> Dict:
        """Store ``player`` through ``add`` and return it.

        No validation is performed here; the dict is passed to ``add`` as-is.
        """
        stored = self.add(player)
        return stored

    def select_season(self, season: int) -> MockQueryResult:
        """Return a MockQueryResult of the players whose ``'season'`` equals ``season``."""
        matching = []
        for player in self._data.values():
            if player.get('season') == season:
                matching.append(player)
        return MockQueryResult(matching)
|
|
|
|
|
|
class MockTeamRepository(EnhancedMockRepository):
    """Team-flavoured in-memory repository built on EnhancedMockRepository."""

    def __init__(self):
        # The base class records this label as the repository name.
        super().__init__("team")

    def add_team(self, team: Dict) -> Dict:
        """Store ``team`` through ``add`` and return it.

        No validation is performed here; the dict is passed to ``add`` as-is.
        """
        stored = self.add(team)
        return stored

    def select_season(self, season: int) -> MockQueryResult:
        """Return a MockQueryResult of the teams whose ``'season'`` equals ``season``."""
        matching = []
        for team in self._data.values():
            if team.get('season') == season:
                matching.append(team)
        return MockQueryResult(matching)
|
|
|
|
|
|
class EnhancedMockCache:
    """In-memory cache mock with TTL expiry, call recording and hit/miss counters.

    Expired entries are removed lazily, the next time the key is inspected.
    """

    def __init__(self):
        self._cache: Dict[str, str] = {}
        # key -> absolute time.time() deadline; keys without an entry never expire
        self._expiry: Dict[str, float] = {}
        # one dict per recorded call, always carrying a 'method' entry
        self._calls: List[Dict] = []
        self._hit_count = 0
        self._miss_count = 0

    def _is_expired(self, key: str) -> bool:
        """Return True if ``key`` had a deadline that has passed, purging it.

        Keys without a recorded deadline (including unknown keys) are
        reported as not expired.
        """
        deadline = self._expiry.get(key)
        if deadline is None or time.time() < deadline:
            return False
        # pop() rather than del so cleanup cannot raise if the maps disagree.
        self._cache.pop(key, None)
        self._expiry.pop(key, None)
        return True

    def get(self, key: str) -> Optional[str]:
        """Return the cached value for ``key``, or None when absent or expired.

        Records the call and counts it as a hit or a miss.
        """
        self._calls.append({'method': 'get', 'key': key})
        if not self._is_expired(key) and key in self._cache:
            self._hit_count += 1
            return self._cache[key]
        self._miss_count += 1
        return None

    def set(self, key: str, value: str, ttl: Optional[int] = 300) -> bool:
        """Store ``value`` under ``key`` and return True.

        Args:
            key: cache key.
            value: value to store; the recorded call keeps at most the first
                200 characters of its string form.
            ttl: lifetime in seconds from now; ``None`` stores the entry
                without an expiry.
        """
        preview = value[:200] if isinstance(value, str) else str(value)[:200]
        self._calls.append({
            'method': 'set',
            'key': key,
            'value': preview,
            'ttl': ttl
        })
        self._cache[key] = value
        if ttl is None:
            # Drop any deadline left over from a previous set() of this key.
            self._expiry.pop(key, None)
        else:
            self._expiry[key] = time.time() + ttl
        return True

    def setex(self, key: str, ttl: int, value: str) -> bool:
        """Same as ``set`` but with the ``(key, ttl, value)`` argument order."""
        return self.set(key, value, ttl)

    def keys(self, pattern: str) -> List[str]:
        """Return the live keys matching the glob ``pattern`` (case-sensitive).

        Expired keys are purged and left out, consistent with ``get`` and
        ``exists``. Records the call.
        """
        self._calls.append({'method': 'keys', 'pattern': pattern})
        # Iterate over a snapshot: _is_expired() may delete from self._cache.
        return [
            key
            for key in list(self._cache)
            if not self._is_expired(key) and fnmatch.fnmatchcase(key, pattern)
        ]

    def delete(self, *keys: str) -> int:
        """Remove the given keys; return how many live entries were removed.

        Keys that are unknown or already expired are not counted.
        Records the call.
        """
        self._calls.append({'method': 'delete', 'keys': list(keys)})
        deleted = 0
        for key in keys:
            if self._is_expired(key):
                continue  # already purged; nothing live was removed
            if key in self._cache:
                del self._cache[key]
                self._expiry.pop(key, None)
                deleted += 1
        return deleted

    def invalidate_pattern(self, pattern: str) -> int:
        """Delete every live key matching ``pattern``; return the number removed.

        Goes through ``keys`` and ``delete``, so both calls are recorded.
        """
        return self.delete(*self.keys(pattern))

    def exists(self, key: str) -> bool:
        """Return True if ``key`` is cached and not expired (call is not recorded)."""
        if self._is_expired(key):
            return False
        return key in self._cache

    def clear(self):
        """Drop all entries, deadlines and recorded calls; reset the counters."""
        self._cache.clear()
        self._expiry.clear()
        self._calls.clear()
        self._hit_count = 0
        self._miss_count = 0

    def get_calls(self, method: Optional[str] = None) -> List[Dict]:
        """Return recorded calls, limited to ``method`` when one is given."""
        if method:
            return [call for call in self._calls if call.get('method') == method]
        return list(self._calls)

    def clear_calls(self):
        """Forget the recorded calls; cached data and counters are kept."""
        self._calls.clear()

    @property
    def hit_rate(self) -> float:
        """Fraction of ``get`` calls that were hits; 0.0 before any ``get``."""
        total = self._hit_count + self._miss_count
        if total == 0:
            return 0.0
        return self._hit_count / total

    def assert_called_with(self, method: str, **kwargs) -> bool:
        """Return True if some recorded ``method`` call has all ``kwargs`` fields.

        Raises:
            AssertionError: when no recorded call matches; the message lists
                the methods that were recorded.
        """
        for call in self._calls:
            if call.get('method') != method:
                continue
            if all(call.get(key) == value for key, value in kwargs.items()):
                return True
        available = [c.get('method') for c in self._calls]
        raise AssertionError(f"Expected {method}({kwargs}) not found. Available: {available}")

    def was_called(self, method: str) -> bool:
        """Return True if any recorded call used ``method``."""
        return any(c.get('method') == method for c in self._calls)
|
|
|
|
|
|
class MockCacheService:
    """Compatibility name: constructing it yields an EnhancedMockCache instance."""

    def __new__(cls):
        # The returned object is an EnhancedMockCache, not an instance of
        # this class, so isinstance(obj, MockCacheService) is False.
        replacement = EnhancedMockCache()
        return replacement
|