113 lines
3.7 KiB
Python
113 lines
3.7 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import os
|
|
import logging
|
|
from datetime import datetime
|
|
|
|
# Child of this module's logger, so this script's records can be filtered or
# configured independently under the "<module>.reset_postgres" name.
logger = logging.getLogger(__name__ + '.reset_postgres')
|
|
|
|
# Lists the names of all ordinary tables in the public schema.
_PUBLIC_TABLES_SQL = """
    SELECT table_name
    FROM information_schema.tables
    WHERE table_schema = 'public'
    AND table_type = 'BASE TABLE'
"""

# Lists the names of all sequences in the public schema.
_PUBLIC_SEQUENCES_SQL = """
    SELECT sequence_name FROM information_schema.sequences
    WHERE sequence_schema = 'public'
"""


def _quote_public_ident(name):
    """Return *name* as a quoted identifier qualified with the ``public`` schema.

    Embedded double quotes are doubled, as PostgreSQL requires inside a
    quoted identifier. The explicit ``public.`` prefix makes the statement
    target the object that was actually listed from ``public``, rather than
    whatever ``search_path`` would resolve the bare name to.
    """
    return 'public."{}"'.format(name.replace('"', '""'))


def _fetch_first_column(db, sql):
    """Execute the single-column SELECT *sql* on *db*; return that column as a list."""
    return [row[0] for row in db.execute_sql(sql).fetchall()]


def _close_quietly(db):
    """Close *db*, logging rather than raising, so cleanup never masks the outcome."""
    try:
        db.close()
    except Exception as close_error:
        logger.warning(f"Failed to close database connection: {close_error}")


def reset_postgres_database():
    """Drop every base table and sequence in the ``public`` schema (test reset).

    Connection settings are read from the environment, with local-development
    fallbacks:

    * ``SBA_DATABASE`` (default ``'sba_master'``): database name
    * ``SBA_DB_USER`` (default ``'sba_admin'``): login role
    * ``SBA_DB_USER_PASSWORD`` (default ``'sba_dev_password_2024'``): password
    * ``POSTGRES_HOST`` (default ``'localhost'``)
    * ``POSTGRES_PORT`` (default ``'5432'``, parsed with ``int``)

    Side effect: sets ``os.environ['DATABASE_TYPE'] = 'postgresql'`` for the
    rest of the process.

    Tables are dropped with ``CASCADE``. Any sequences still listed afterwards
    are then dropped as well (dropped, not restarted). If no tables are found,
    the function returns early and sequences are left untouched.

    NOTE(review): ``information_schema`` only lists objects the connecting
    role has some privilege on, so objects invisible to that role are not
    dropped.

    Returns:
        True on success (including "already clean"). False if any error
        occurred while connecting, querying or dropping; the error is logged.

    Raises:
        ValueError: if ``POSTGRES_PORT`` is not an integer (parsed before the
            guarded section).
        ImportError: if peewee is not installed (imported before the guarded
            section).
    """
    # Set PostgreSQL environment
    os.environ['DATABASE_TYPE'] = 'postgresql'

    # Use environment variables with fallbacks
    db_name = os.environ.get('SBA_DATABASE', 'sba_master')
    db_user = os.environ.get('SBA_DB_USER', 'sba_admin')
    db_password = os.environ.get('SBA_DB_USER_PASSWORD', 'sba_dev_password_2024')
    db_host = os.environ.get('POSTGRES_HOST', 'localhost')
    db_port = int(os.environ.get('POSTGRES_PORT', '5432'))

    # Direct PostgreSQL connection (avoid db_engine complications)
    from peewee import PostgresqlDatabase

    # Stays None if the constructor itself fails, so cleanup knows to skip close().
    db = None
    try:
        logger.info(f"Connecting to PostgreSQL at {db_host}:{db_port}...")
        db = PostgresqlDatabase(
            db_name,
            user=db_user,
            password=db_password,
            host=db_host,
            port=db_port,
        )
        db.connect()

        logger.info("Querying for existing tables...")
        try:
            table_names = _fetch_first_column(db, _PUBLIC_TABLES_SQL)
        except Exception as query_error:
            logger.error(f"Query execution error: {query_error}")
            raise

        logger.info(f"Query returned {len(table_names)} results")
        for table_name in table_names:
            logger.info(f"Found table: {table_name}")

        if not table_names:
            logger.info("No tables found - database already clean")
            return True

        logger.info(f"Found {len(table_names)} tables to drop")

        # CASCADE also removes objects that depend on each table, so the
        # drop order between tables does not matter.
        for table_name in table_names:
            logger.info(f" Dropping table: {table_name}")
            db.execute_sql(
                f'DROP TABLE IF EXISTS {_quote_public_ident(table_name)} CASCADE'
            )

        # Drop whatever sequences are still listed once the tables are gone.
        for seq_name in _fetch_first_column(db, _PUBLIC_SEQUENCES_SQL):
            logger.info(f" Resetting sequence: {seq_name}")
            db.execute_sql(
                f'DROP SEQUENCE IF EXISTS {_quote_public_ident(seq_name)} CASCADE'
            )

        logger.info("✓ PostgreSQL database reset complete")
        return True

    except Exception as e:
        logger.error(f"✗ Database reset failed: {e}")
        return False

    finally:
        if db is not None:
            _close_quietly(db)
|
|
|
|
def main():
    """Configure logging, run the database reset and report the outcome.

    Returns:
        Process exit status: 0 if ``reset_postgres_database()`` returned a
        truthy value, otherwise 1.
    """
    logging.basicConfig(
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        level=logging.INFO,
    )
    logger.info("=== PostgreSQL Database Reset ===")

    # Guard clause: report the failure and bail out with a non-zero status.
    if not reset_postgres_database():
        logger.error("❌ Database reset failed")
        return 1

    logger.info("🗑️ Database reset successful - ready for fresh migration")
    return 0
|
|
|
|
if __name__ == "__main__":
    # Use sys.exit rather than the bare exit() builtin: exit() is injected by
    # the `site` module and is absent under `python -S` and in some embedded
    # or frozen interpreters.
    import sys

    sys.exit(main())