51 lines
1.9 KiB
Python
51 lines
1.9 KiB
Python
"""Creates and restores the demo database snapshot used for hourly resets."""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import os
|
|
from pathlib import Path
|
|
|
|
# Where the gzip-compressed SQL snapshot lives; override via DEMO_SNAPSHOT_PATH.
SNAPSHOT_PATH = Path(
    os.getenv("DEMO_SNAPSHOT_PATH", "/app/demo_data/demo_snapshot.sql.gz")
)
|
|
|
|
|
|
async def create_snapshot() -> None:
    """pg_dump the current DB to SNAPSHOT_PATH (gzip compressed).

    Reads the connection URL from ``DATABASE_URL`` and rewrites a
    ``postgresql+asyncpg`` (SQLAlchemy driver) scheme to plain ``postgresql``
    so ``pg_dump`` can use it. The dump is taken with ``--no-owner --no-acl``.
    The output is gzip-compressed in-process and written to a sibling ``.tmp``
    file, which is then moved over ``SNAPSHOT_PATH`` with ``os.replace``. A
    failed run therefore never leaves a truncated snapshot behind.

    Raises:
        RuntimeError: ``pg_dump`` exited non-zero, or the compressed snapshot
            could not be written. The previous snapshot file is kept as-is.
    """
    import contextlib
    import gzip

    db_url = os.environ.get("DATABASE_URL", "")
    pg_url = db_url.replace("postgresql+asyncpg", "postgresql")

    # Exec pg_dump directly rather than through a shell, for two reasons:
    #  * The URL (which may embed a password) is passed as a single argv
    #    entry, so characters like `$`, backticks or `"` are never
    #    shell-interpreted.
    #  * proc.returncode is pg_dump's own status. In a
    #    `pg_dump ... | gzip > file` pipeline it would be gzip's status, so a
    #    failed dump would be reported as success and clobber the snapshot.
    proc = await asyncio.create_subprocess_exec(
        "pg_dump",
        "--no-owner",
        "--no-acl",
        "--dbname",
        pg_url,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    # NOTE: the whole dump is buffered in memory. Presumably fine for the demo
    # database; revisit if snapshots grow large.
    dump, err = await proc.communicate()
    if proc.returncode != 0:
        raise RuntimeError(f"Snapshot failed: {err.decode(errors='replace')}")

    tmp_path = SNAPSHOT_PATH.with_name(SNAPSHOT_PATH.name + ".tmp")

    def _write_snapshot() -> None:
        # Blocking compression + file I/O; executed in the default executor.
        SNAPSHOT_PATH.parent.mkdir(parents=True, exist_ok=True)
        tmp_path.write_bytes(gzip.compress(dump))
        # Same-directory rename: readers see either the old or the new file,
        # never a partially written one.
        os.replace(tmp_path, SNAPSHOT_PATH)

    try:
        await asyncio.get_running_loop().run_in_executor(None, _write_snapshot)
    except OSError as exc:
        # Best-effort cleanup; don't let it mask the original error.
        with contextlib.suppress(OSError):
            tmp_path.unlink()
        raise RuntimeError(
            f"Snapshot failed: could not write {SNAPSHOT_PATH}: {exc}"
        ) from exc
|
|
|
|
|
|
async def restore_snapshot() -> None:
    """Restore DB from SNAPSHOT_PATH by replaying its gzip-compressed SQL with psql.

    The snapshot is read and decompressed in-process before psql is started.
    A corrupt or truncated file is therefore rejected before any SQL reaches
    the database. The SQL is then fed on stdin to
    ``psql --single-transaction -v ON_ERROR_STOP=1 --file=-``. The first
    failing statement stops psql, and the wrapping transaction is rolled back.

    Connection details come from ``DATABASE_URL``. A ``postgresql+asyncpg``
    (SQLAlchemy driver) scheme is rewritten to plain ``postgresql`` for psql.

    Raises:
        FileNotFoundError: ``SNAPSHOT_PATH`` does not exist.
        RuntimeError: the snapshot cannot be read or decompressed, or psql
            exits non-zero.
    """
    import gzip
    import zlib

    if not SNAPSHOT_PATH.exists():
        raise FileNotFoundError(f"Snapshot not found: {SNAPSHOT_PATH}")

    db_url = os.environ.get("DATABASE_URL", "")
    pg_url = db_url.replace("postgresql+asyncpg", "postgresql")

    # NOTE(review): this TRUNCATE is defined but is NOT sent to the database,
    # even though the intent is to truncate the user-data tables before
    # restoring. It is deliberately left unexecuted for now:
    #  * create_snapshot() runs pg_dump without --data-only or --clean, so a
    #    snapshot it produces presumably contains CREATE TABLE statements.
    #    Those fail against existing tables whether or not they were
    #    truncated first.
    #  * Prepending the TRUNCATE would also make restoring into an empty
    #    database fail.
    # Either pair a data-only snapshot with this TRUNCATE (inside the same
    # psql transaction), or take the snapshot with
    # `pg_dump --clean --if-exists`. TODO confirm the intended reset flow.
    truncate_sql = """
        TRUNCATE TABLE
            manual_cgt_disposals, payslips, tax_profiles, tax_rate_configs,
            investment_transactions, investment_holdings, assets,
            audit_logs, net_worth_snapshots,
            transactions, budgets, accounts, categories,
            sessions, users
        RESTART IDENTITY CASCADE;
    """

    loop = asyncio.get_running_loop()
    try:
        # Blocking read + decompress, kept off the event loop. Decompressing
        # fully up front means gzip errors surface here, not mid-restore.
        sql = await loop.run_in_executor(
            None, lambda: gzip.decompress(SNAPSHOT_PATH.read_bytes())
        )
    except (OSError, EOFError, zlib.error) as exc:
        raise RuntimeError(
            f"Restore failed: cannot read snapshot {SNAPSHOT_PATH}: {exc}"
        ) from exc

    # Exec psql directly (no shell) so the URL is never shell-interpreted.
    # `--file=-` makes psql read stdin *as a script file*. psql applies
    # --single-transaction only to -f/-c input, so plain stdin redirection
    # would run the restore without the wrapping BEGIN/COMMIT.
    # --no-psqlrc stops a ~/.psqlrc from altering this scripted run.
    proc = await asyncio.create_subprocess_exec(
        "psql",
        "--no-psqlrc",
        "--single-transaction",
        "-v",
        "ON_ERROR_STOP=1",
        "--file=-",
        "--dbname",
        pg_url,
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    # stdout (psql's per-statement status lines) is captured and discarded.
    _, err = await proc.communicate(input=sql)
    if proc.returncode != 0:
        raise RuntimeError(f"Restore failed: {err.decode(errors='replace')}")
|