"""Creates and restores the demo database snapshot used for hourly resets.""" from __future__ import annotations import asyncio import os from pathlib import Path SNAPSHOT_PATH = Path(os.environ.get("DEMO_SNAPSHOT_PATH", "/app/demo_data/demo_snapshot.sql.gz")) async def create_snapshot() -> None: """pg_dump the current DB to SNAPSHOT_PATH (gzip compressed).""" db_url = os.environ.get("DATABASE_URL", "") pg_url = db_url.replace("postgresql+asyncpg", "postgresql") proc = await asyncio.create_subprocess_shell( f'pg_dump --no-owner --no-acl "{pg_url}" | gzip > "{SNAPSHOT_PATH}"', stderr=asyncio.subprocess.PIPE, ) _, err = await proc.communicate() if proc.returncode != 0: raise RuntimeError(f"Snapshot failed: {err.decode()}") async def restore_snapshot() -> None: """Restore DB from SNAPSHOT_PATH, dropping and recreating all user data.""" if not SNAPSHOT_PATH.exists(): raise FileNotFoundError(f"Snapshot not found: {SNAPSHOT_PATH}") db_url = os.environ.get("DATABASE_URL", "") pg_url = db_url.replace("postgresql+asyncpg", "postgresql") # Truncate all user-data tables in dependency order, then restore truncate_sql = """ TRUNCATE TABLE manual_cgt_disposals, payslips, tax_profiles, tax_rate_configs, investment_transactions, investment_holdings, assets, audit_logs, net_worth_snapshots, transactions, budgets, accounts, categories, sessions, users RESTART IDENTITY CASCADE; """ proc = await asyncio.create_subprocess_shell( f'gunzip -c "{SNAPSHOT_PATH}" | psql --single-transaction -v ON_ERROR_STOP=1 "{pg_url}"', stderr=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, ) _, err = await proc.communicate() if proc.returncode != 0: raise RuntimeError(f"Restore failed: {err.decode()}")