feat: add migration tracking system (#81) #96

Open
Claude wants to merge 1 commits from issue/81-add-migration-tracking-system into main
3 changed files with 97 additions and 1 deletions

1
.gitignore vendored
View File

@ -55,7 +55,6 @@ Include/
pyvenv.cfg
db_engine.py
main.py
migrations.py
db_engine.py
sba_master.db
db_engine.py

88
migrations.py Normal file
View File

@ -0,0 +1,88 @@
#!/usr/bin/env python3
"""Apply pending SQL migrations and record them in schema_versions.
Usage:
python migrations.py
Connects to PostgreSQL using the same environment variables as the API:
POSTGRES_DB (default: sba_master)
POSTGRES_USER (default: sba_admin)
POSTGRES_PASSWORD (required)
POSTGRES_HOST (default: sba_postgres)
POSTGRES_PORT (default: 5432)
On first run against an existing database, all migrations will be applied.
All migration files use IF NOT EXISTS guards so re-applying is safe.
"""
import os
import sys
from pathlib import Path
import psycopg2
MIGRATIONS_DIR = Path(__file__).parent / "migrations"
_CREATE_SCHEMA_VERSIONS = """
CREATE TABLE IF NOT EXISTS schema_versions (
filename VARCHAR(255) PRIMARY KEY,
applied_at TIMESTAMP NOT NULL DEFAULT NOW()
);
"""
def _get_connection():
password = os.environ.get("POSTGRES_PASSWORD")
if password is None:
raise RuntimeError("POSTGRES_PASSWORD environment variable is not set")
return psycopg2.connect(
dbname=os.environ.get("POSTGRES_DB", "sba_master"),
user=os.environ.get("POSTGRES_USER", "sba_admin"),
password=password,
host=os.environ.get("POSTGRES_HOST", "sba_postgres"),
port=int(os.environ.get("POSTGRES_PORT", "5432")),
)
def main():
conn = _get_connection()
try:
with conn:
with conn.cursor() as cur:
cur.execute(_CREATE_SCHEMA_VERSIONS)
with conn.cursor() as cur:
cur.execute("SELECT filename FROM schema_versions")
applied = {row[0] for row in cur.fetchall()}
migration_files = sorted(MIGRATIONS_DIR.glob("*.sql"))
pending = [f for f in migration_files if f.name not in applied]
if not pending:
print("No pending migrations.")
return
for migration_file in pending:
print(f"Applying {migration_file.name} ...", end=" ", flush=True)
sql = migration_file.read_text()
with conn:
with conn.cursor() as cur:
cur.execute(sql)
cur.execute(
"INSERT INTO schema_versions (filename) VALUES (%s)",
(migration_file.name,),
)
print("done")
print(f"\nApplied {len(pending)} migration(s).")
finally:
conn.close()
if __name__ == "__main__":
try:
main()
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
sys.exit(1)

View File

@ -0,0 +1,9 @@
-- Migration: Add schema_versions table for migration tracking
-- Date: 2026-03-27
-- Description: Creates a table to record which SQL migrations have been applied,
-- preventing double-application and missed migrations across environments.
CREATE TABLE IF NOT EXISTS schema_versions (
filename VARCHAR(255) PRIMARY KEY,
applied_at TIMESTAMP NOT NULL DEFAULT NOW()
);