#!/usr/bin/env python3 import os import logging from datetime import datetime logger = logging.getLogger(f'{__name__}.reset_postgres') def reset_postgres_database(): """Complete reset of PostgreSQL database for testing""" # Set PostgreSQL environment os.environ['DATABASE_TYPE'] = 'postgresql' # Use environment variables with fallbacks db_name = os.environ.get('SBA_DATABASE', 'sba_master') db_user = os.environ.get('SBA_DB_USER', 'sba_admin') db_password = os.environ.get('SBA_DB_USER_PASSWORD', 'sba_dev_password_2024') db_host = os.environ.get('POSTGRES_HOST', 'localhost') db_port = int(os.environ.get('POSTGRES_PORT', '5432')) # Direct PostgreSQL connection (avoid db_engine complications) from peewee import PostgresqlDatabase try: logger.info(f"Connecting to PostgreSQL at {db_host}:{db_port}...") db = PostgresqlDatabase( db_name, user=db_user, password=db_password, host=db_host, port=db_port ) db.connect() # Get list of all tables in public schema - use simpler query logger.info("Querying for existing tables...") try: # Use psql-style \dt query converted to SQL tables_result = db.execute_sql(""" SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_type = 'BASE TABLE' """).fetchall() logger.info(f"Query returned {len(tables_result)} results") table_names = [] for table_row in tables_result: table_name = table_row[0] table_names.append(table_name) logger.info(f"Found table: {table_name}") except Exception as query_error: logger.error(f"Query execution error: {query_error}") raise if not table_names: logger.info("No tables found - database already clean") db.close() return True logger.info(f"Found {len(table_names)} tables to drop") # Disable foreign key checks and drop all tables for table_name in table_names: logger.info(f" Dropping table: {table_name}") db.execute_sql(f'DROP TABLE IF EXISTS "{table_name}" CASCADE') # Reset sequences (auto-increment counters) sequences_query = """ SELECT sequence_name FROM information_schema.sequences WHERE sequence_schema = 'public' """ sequences_result = db.execute_sql(sequences_query).fetchall() for seq in sequences_result: if seq and len(seq) > 0: seq_name = seq[0] logger.info(f" Resetting sequence: {seq_name}") db.execute_sql(f'DROP SEQUENCE IF EXISTS "{seq_name}" CASCADE') db.close() logger.info("✓ PostgreSQL database reset complete") return True except Exception as e: logger.error(f"✗ Database reset failed: {e}") try: db.close() except: pass return False def main(): logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger.info("=== PostgreSQL Database Reset ===") success = reset_postgres_database() if success: logger.info("🗑️ Database reset successful - ready for fresh migration") else: logger.error("❌ Database reset failed") return 0 if success else 1 if __name__ == "__main__": exit(main())