""" Enhanced Mock Implementations for Testing Provides comprehensive in-memory mocks for full test coverage. """ from typing import List, Dict, Any, Optional, Callable from collections import defaultdict import time import fnmatch class MockQueryResult: """Enhanced mock query result that supports chaining and complex queries.""" def __init__(self, items: List[Dict[str, Any]]): self._items = list(items) self._original_items = list(items) self._filters: List[Callable] = [] self._order_by_field = None self._order_by_desc = False def where(self, *conditions) -> 'MockQueryResult': """Apply WHERE conditions.""" result = MockQueryResult(self._original_items.copy()) result._filters = self._filters.copy() def apply_filter(item): for condition in conditions: if callable(condition): if not condition(item): return False elif isinstance(condition, tuple): field, op, value = condition item_val = item.get(field) if op == '<<': # IN if item_val not in value: return False elif op == '==': if item_val != value: return False elif op == '!=': if item_val == value: return False return True filtered = [i for i in self._items if apply_filter(i)] result._items = filtered return result def order_by(self, *fields) -> 'MockQueryResult': """Order by fields.""" result = MockQueryResult(self._items.copy()) def get_sort_key(item): values = [] for field in fields: neg = False if hasattr(field, '__neg__'): field = -field neg = True val = item.get(str(field), 0) if isinstance(val, (int, float)): values.append(-val if neg else val) else: values.append(val) return tuple(values) result._items.sort(key=get_sort_key) return result def count(self) -> int: return len(self._items) def __iter__(self): return iter(self._items) def __len__(self): return len(self._items) def __getitem__(self, index): return self._items[index] class EnhancedMockRepository: """Enhanced mock repository with full CRUD support.""" def __init__(self, name: str = "entity"): self._data: Dict[int, Dict] = {} self._id_counter = 1 self._name = name self._last_query = None def _make_id(self, item: Dict) -> int: """Generate or use existing ID.""" if 'id' not in item or item['id'] is None: item['id'] = self._id_counter self._id_counter += 1 else: # Update counter if existing ID is >= current counter if item['id'] >= self._id_counter: self._id_counter = item['id'] + 1 return item['id'] def select_season(self, season: int) -> MockQueryResult: """Get all items for a season.""" items = [v for v in self._data.values() if v.get('season') == season] return MockQueryResult(items) def get_by_id(self, entity_id: int) -> Optional[Dict]: """Get item by ID.""" return self._data.get(entity_id) def get_or_none(self, *conditions, **field_conditions) -> Optional[Dict]: """Get first item matching conditions.""" # Convert field_conditions to conditions converted_conditions = list(conditions) for field, value in field_conditions.items(): converted_conditions.append(lambda item, f=field, v=value: item.get(f) == v) for item in self._data.values(): if self._matches(item, converted_conditions): return item return None def _matches(self, item: Dict, conditions) -> bool: """Check if item matches conditions.""" for condition in conditions: if callable(condition): if not condition(item): return False return True def update(self, data: Dict, *conditions) -> int: """Update items matching conditions.""" updated = 0 for item in self._data.values(): if self._matches(item, conditions): for key, value in data.items(): item[key] = value updated += 1 return updated def insert_many(self, data: List[Dict]) -> int: """Insert multiple items.""" count = 0 for item in data: self.add(item) count += 1 return count def delete_by_id(self, entity_id: int) -> int: """Delete item by ID.""" if entity_id in self._data: del self._data[entity_id] return 1 return 0 def add(self, item: Dict) -> Dict: """Add item to repository.""" self._make_id(item) self._data[item['id']] = item return item def clear(self): """Clear all data.""" self._data.clear() self._id_counter = 1 def all(self) -> List[Dict]: """Get all items.""" return list(self._data.values()) def count(self) -> int: """Count all items.""" return len(self._data) class MockPlayerRepository(EnhancedMockRepository): """In-memory mock of player database.""" def __init__(self): super().__init__("player") def add_player(self, player: Dict) -> Dict: """Add player with validation.""" return self.add(player) def select_season(self, season: int) -> MockQueryResult: """Get all players for a season (0 = all seasons).""" if season == 0: # Return all players items = list(self._data.values()) else: items = [p for p in self._data.values() if p.get('season') == season] return MockQueryResult(items) def update(self, data: Dict, player_id: int) -> int: """Update player by ID (matches RealPlayerRepository signature).""" if player_id in self._data: for key, value in data.items(): self._data[player_id][key] = value return 1 return 0 class MockTeamRepository(EnhancedMockRepository): """In-memory mock of team database.""" def __init__(self): super().__init__("team") def add_team(self, team: Dict) -> Dict: """Add team with validation.""" return self.add(team) def select_season(self, season: int) -> MockQueryResult: """Get all teams for a season.""" if season == 0: # Return all teams items = list(self._data.values()) else: items = [t for t in self._data.values() if t.get('season') == season] return MockQueryResult(items) def update(self, data: Dict, team_id: int) -> int: """Update team by ID (matches RealTeamRepository signature).""" if team_id in self._data: for key, value in data.items(): self._data[team_id][key] = value return 1 return 0 class EnhancedMockCache: """Enhanced mock cache with call tracking and TTL support.""" def __init__(self): self._cache: Dict[str, str] = {} self._expiry: Dict[str, float] = {} self._calls: List[Dict] = [] self._hit_count = 0 self._miss_count = 0 def _is_expired(self, key: str) -> bool: """Check if key is expired.""" if key not in self._expiry: return False if time.time() < self._expiry[key]: return False # Clean up expired key del self._cache[key] del self._expiry[key] return True def get(self, key: str) -> Optional[str]: """Get cached value.""" self._calls.append({'method': 'get', 'key': key}) if self._is_expired(key): self._miss_count += 1 return None if key in self._cache: self._hit_count += 1 return self._cache[key] self._miss_count += 1 return None def set(self, key: str, value: str, ttl: int = 300) -> bool: """Set cached value with TTL.""" self._calls.append({ 'method': 'set', 'key': key, 'value': value[:200] if isinstance(value, str) else str(value)[:200], 'ttl': ttl }) self._cache[key] = value self._expiry[key] = time.time() + ttl return True def setex(self, key: str, ttl: int, value: str) -> bool: """Set with explicit expiry (alias).""" return self.set(key, value, ttl) def keys(self, pattern: str) -> List[str]: """Get keys matching pattern.""" self._calls.append({'method': 'keys', 'pattern': pattern}) return [k for k in self._cache.keys() if fnmatch.fnmatch(k, pattern)] def delete(self, *keys: str) -> int: """Delete specific keys.""" self._calls.append({'method': 'delete', 'keys': list(keys)}) deleted = 0 for key in keys: if key in self._cache: del self._cache[key] if key in self._expiry: del self._expiry[key] deleted += 1 return deleted def invalidate_pattern(self, pattern: str) -> int: """Delete all keys matching pattern.""" keys = self.keys(pattern) return self.delete(*keys) def exists(self, key: str) -> bool: """Check if key exists and not expired.""" if self._is_expired(key): return False return key in self._cache def clear(self): """Clear all cached data.""" self._cache.clear() self._expiry.clear() self._calls.clear() self._hit_count = 0 self._miss_count = 0 def get_calls(self, method: Optional[str] = None) -> List[Dict]: """Get tracked calls.""" if method: return [c for c in self._calls if c.get('method') == method] return list(self._calls) def clear_calls(self): """Clear call history.""" self._calls.clear() @property def hit_rate(self) -> float: """Get cache hit rate.""" total = self._hit_count + self._miss_count if total == 0: return 0.0 return self._hit_count / total def assert_called_with(self, method: str, **kwargs) -> bool: """Assert a method was called with specific args.""" for call in self._calls: if call.get('method') == method: for key, value in kwargs.items(): if call.get(key) != value: break else: return True available = [c.get('method') for c in self._calls] raise AssertionError(f"Expected {method}({kwargs}) not found. Available: {available}") def was_called(self, method: str) -> bool: """Check if method was called.""" return any(c.get('method') == method for c in self._calls) class MockCacheService: """Alias for EnhancedMockCache for compatibility.""" def __new__(cls): return EnhancedMockCache()