Complete Phase 3, Phase 5 polish and hardening
Phase 3 — Investments: - Multi-currency support: holdings track purchase currency, FX rates convert to base for totals - Capital gains report using UK Section 104 pool method, grouped by tax year - Capital Gains tab added to Reports page Phase 5 — Polish & Hardening: - Mobile-responsive layout: bottom nav, sidebar hidden on mobile, logo in TopBar, compact header buttons, hover-only actions now always visible on touch - Backup system: encrypted GPG backups via backup.sh, nightly scheduler job, admin API (list/trigger/download/restore), Settings UI with drag-to-restore confirmation - Docker entrypoint with gosu privilege drop to fix bind-mount ownership on fresh deployments - OWASP fixes: refresh token now bound to its session (new refresh_token_hash column + migration), CSRF secure flag tied to environment, IP-level rate limiting on login, TOTPEnableRequest Pydantic schema replaces raw dict - AES-256-GCM key rotation script (rotate_keys.py) with dry-run mode and atomic DB transaction - CLAUDE.md added for AI-assisted development context - README updated: correct reverse proxy port, accurate backup/restore commands, key rotation instructions Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
74e57a35c0
commit
fe4e69b9ad
40 changed files with 2079 additions and 127 deletions
|
|
@ -1,6 +1,10 @@
|
|||
FROM python:3.12-slim AS base
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libmagic1 \
|
||||
postgresql-client \
|
||||
gnupg \
|
||||
gzip \
|
||||
gosu \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
RUN pip install --no-cache-dir uv
|
||||
WORKDIR /app
|
||||
|
|
@ -13,7 +17,13 @@ FROM deps AS production
|
|||
COPY app/ ./app/
|
||||
COPY alembic/ ./alembic/
|
||||
COPY alembic.ini ./
|
||||
RUN useradd -r -s /bin/false -u 1001 appuser && chown -R appuser /app && mkdir -p /app/uploads && chown appuser /app/uploads
|
||||
USER appuser
|
||||
COPY scripts/ ./scripts/
|
||||
RUN useradd -r -s /bin/false -u 1001 appuser \
|
||||
&& chown -R appuser /app \
|
||||
&& mkdir -p /app/uploads /app/backups \
|
||||
&& chown appuser /app/uploads /app/backups
|
||||
COPY scripts/entrypoint.sh /entrypoint.sh
|
||||
RUN chmod +x /entrypoint.sh
|
||||
EXPOSE 8000
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
CMD ["sh", "-c", "python -m alembic upgrade head && uvicorn app.main:app --host 0.0.0.0 --port 8000 --workers 2 --proxy-headers"]
|
||||
|
|
|
|||
23
backend/alembic/versions/0002_refresh_token_hash.py
Normal file
23
backend/alembic/versions/0002_refresh_token_hash.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
"""add refresh_token_hash to sessions
|
||||
|
||||
Revision ID: 0002
|
||||
Revises: 0001
|
||||
Create Date: 2026-04-22
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "0002"
|
||||
down_revision = "0001"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("sessions", sa.Column("refresh_token_hash", sa.Text, nullable=True))
|
||||
op.create_index("ix_sessions_refresh_token_hash", "sessions", ["refresh_token_hash"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_sessions_refresh_token_hash", table_name="sessions")
|
||||
op.drop_column("sessions", "refresh_token_hash")
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
from fastapi import APIRouter
|
||||
|
||||
from app.api.v1 import auth, users, accounts, categories, transactions, budgets, reports, investments, predictions
|
||||
from app.api.v1 import auth, users, accounts, categories, transactions, budgets, reports, investments, predictions, admin
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(auth.router, prefix="/auth", tags=["auth"])
|
||||
|
|
@ -12,3 +12,4 @@ router.include_router(budgets.router)
|
|||
router.include_router(reports.router)
|
||||
router.include_router(investments.router)
|
||||
router.include_router(predictions.router)
|
||||
router.include_router(admin.router)
|
||||
|
|
|
|||
152
backend/app/api/v1/admin.py
Normal file
152
backend/app/api/v1/admin.py
Normal file
|
|
@ -0,0 +1,152 @@
|
|||
import asyncio
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.dependencies import get_current_user
|
||||
from app.db.models.user import User
|
||||
|
||||
router = APIRouter(prefix="/admin", tags=["admin"])
|
||||
|
||||
BACKUP_DIR = Path(os.environ.get("BACKUP_DIR", "/app/backups"))
|
||||
BACKUP_PATTERN = re.compile(r"^\d{8}_\d{6}\.sql\.gz\.gpg$")
|
||||
|
||||
|
||||
class BackupFile(BaseModel):
|
||||
filename: str
|
||||
size_bytes: int
|
||||
created_at: str
|
||||
|
||||
|
||||
class BackupResult(BaseModel):
|
||||
ok: bool
|
||||
message: str
|
||||
|
||||
|
||||
def _list_backup_files() -> list[BackupFile]:
|
||||
if not BACKUP_DIR.exists():
|
||||
return []
|
||||
files = []
|
||||
for f in sorted(BACKUP_DIR.glob("*.sql.gz.gpg"), reverse=True):
|
||||
stat = f.stat()
|
||||
files.append(BackupFile(
|
||||
filename=f.name,
|
||||
size_bytes=stat.st_size,
|
||||
created_at=datetime.fromtimestamp(stat.st_mtime).isoformat(),
|
||||
))
|
||||
return files
|
||||
|
||||
|
||||
@router.get("/backups", response_model=list[BackupFile])
|
||||
async def list_backups(current_user: User = Depends(get_current_user)):
|
||||
return _list_backup_files()
|
||||
|
||||
|
||||
@router.post("/backup", response_model=BackupResult)
|
||||
async def trigger_backup(current_user: User = Depends(get_current_user)):
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
"bash", "/app/scripts/backup.sh",
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.STDOUT,
|
||||
)
|
||||
stdout, _ = await proc.communicate()
|
||||
output = stdout.decode().strip() if stdout else ""
|
||||
if proc.returncode == 0:
|
||||
return BackupResult(ok=True, message=output or "Backup completed")
|
||||
raise HTTPException(status_code=500, detail=f"Backup failed: {output}")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.get("/backups/{filename}")
|
||||
async def download_backup(
|
||||
filename: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
if not BACKUP_PATTERN.match(filename):
|
||||
raise HTTPException(status_code=400, detail="Invalid filename")
|
||||
path = BACKUP_DIR / filename
|
||||
if not path.exists():
|
||||
raise HTTPException(status_code=404, detail="Backup not found")
|
||||
return FileResponse(
|
||||
path=str(path),
|
||||
filename=filename,
|
||||
media_type="application/octet-stream",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/restore/{filename}", response_model=BackupResult)
|
||||
async def restore_backup(
|
||||
filename: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
if not BACKUP_PATTERN.match(filename):
|
||||
raise HTTPException(status_code=400, detail="Invalid filename")
|
||||
path = BACKUP_DIR / filename
|
||||
if not path.exists():
|
||||
raise HTTPException(status_code=404, detail="Backup not found")
|
||||
|
||||
passphrase = os.environ.get("BACKUP_PASSPHRASE", "")
|
||||
if not passphrase:
|
||||
raise HTTPException(status_code=500, detail="BACKUP_PASSPHRASE not configured")
|
||||
|
||||
database_url = os.environ.get("DATABASE_URL", "")
|
||||
pg_url = database_url.replace("postgresql+asyncpg", "postgresql")
|
||||
|
||||
# Ensure GPG has a writable home (appuser has no real home directory)
|
||||
gnupg_home = "/tmp/.gnupg"
|
||||
os.makedirs(gnupg_home, mode=0o700, exist_ok=True)
|
||||
gpg_env = {**os.environ, "GNUPGHOME": gnupg_home}
|
||||
|
||||
try:
|
||||
# Decrypt and decompress into psql non-interactively
|
||||
gpg_proc = await asyncio.create_subprocess_exec(
|
||||
"gpg", "--batch", "--yes", "--no-symkey-cache",
|
||||
"--pinentry-mode", "loopback",
|
||||
"--decrypt", "--passphrase", passphrase,
|
||||
str(path),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
env=gpg_env,
|
||||
)
|
||||
gpg_out, gpg_err = await gpg_proc.communicate()
|
||||
if gpg_proc.returncode != 0:
|
||||
raise HTTPException(status_code=500, detail=f"Decryption failed: {gpg_err.decode()}")
|
||||
|
||||
gunzip_proc = await asyncio.create_subprocess_exec(
|
||||
"gunzip",
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
sql_data, gunzip_err = await gunzip_proc.communicate(input=gpg_out)
|
||||
if gunzip_proc.returncode != 0:
|
||||
raise HTTPException(status_code=500, detail=f"Decompression failed: {gunzip_err.decode()}")
|
||||
|
||||
psql_proc = await asyncio.create_subprocess_exec(
|
||||
"psql",
|
||||
"--single-transaction",
|
||||
"-v", "ON_ERROR_STOP=1",
|
||||
pg_url,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
psql_out, psql_err = await psql_proc.communicate(input=sql_data)
|
||||
if psql_proc.returncode != 0:
|
||||
detail = (psql_err.decode().strip() or psql_out.decode().strip() or "psql exited with no output")
|
||||
raise HTTPException(status_code=500, detail=f"Restore failed: {detail}")
|
||||
|
||||
return BackupResult(ok=True, message=f"Restored from {filename}")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
|
@ -10,6 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from app.core.audit import write_audit
|
||||
from app.core.rate_limiter import is_rate_limited
|
||||
from app.core.security import create_refresh_token, decode_token, generate_csrf_token, hash_token
|
||||
from app.schemas.auth import TOTPEnableRequest
|
||||
from app.dependencies import get_current_user, get_db, get_redis
|
||||
from app.schemas.auth import (
|
||||
LoginRequest,
|
||||
|
|
@ -96,6 +97,11 @@ async def login(
|
|||
db: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
ip = _ip(request) or "unknown"
|
||||
limited, _ = await is_rate_limited(redis, f"rate:login:{ip}", limit=20, window_seconds=60)
|
||||
if limited:
|
||||
raise HTTPException(status_code=429, detail="Too many login attempts — try again shortly")
|
||||
|
||||
try:
|
||||
user, access_token, refresh_token = await authenticate_user(
|
||||
db, redis, body.email, body.password, _ip(request), _ua(request)
|
||||
|
|
@ -171,24 +177,27 @@ async def refresh_token(
|
|||
|
||||
user_id = uuid.UUID(payload["sub"])
|
||||
now = datetime.now(timezone.utc)
|
||||
refresh_hash = hash_token(token)
|
||||
|
||||
# Find and update session
|
||||
# Find the specific session this refresh token was issued for
|
||||
result = await db.execute(
|
||||
select(Session).where(
|
||||
Session.user_id == user_id,
|
||||
Session.refresh_token_hash == refresh_hash,
|
||||
Session.revoked_at.is_(None),
|
||||
Session.expires_at > now,
|
||||
)
|
||||
)
|
||||
session = result.scalars().first()
|
||||
if not session:
|
||||
raise HTTPException(status_code=401, detail="Session not found")
|
||||
raise HTTPException(status_code=401, detail="Session not found or refresh token already used")
|
||||
|
||||
new_access = create_access_token(str(user_id))
|
||||
new_refresh = create_refresh_token(str(user_id))
|
||||
|
||||
# Rotate session token hash
|
||||
# Rotate both token hashes — old refresh token is now invalid
|
||||
session.token_hash = hash_token(new_access)
|
||||
session.refresh_token_hash = hash_token(new_refresh)
|
||||
session.last_active_at = now
|
||||
await db.commit()
|
||||
|
||||
|
|
@ -305,17 +314,13 @@ async def totp_verify(
|
|||
|
||||
@router.post("/totp/enable", status_code=200)
|
||||
async def totp_enable(
|
||||
body: dict,
|
||||
body: TOTPEnableRequest,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
secret = body.get("secret")
|
||||
code = body.get("code")
|
||||
if not secret or not code:
|
||||
raise HTTPException(status_code=422, detail="secret and code required")
|
||||
try:
|
||||
await enable_totp(user, db, secret, code)
|
||||
await enable_totp(user, db, body.secret, body.code)
|
||||
except AuthError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.detail)
|
||||
await write_audit(db, user_id=user.id, action="totp_enable", ip_address=_ip(request))
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
import uuid
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import select, delete as sa_delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_current_user, get_db
|
||||
|
|
@ -9,6 +11,7 @@ from app.db.models.user import User
|
|||
from app.schemas.investment import (
|
||||
AssetSearch,
|
||||
AssetPricePoint,
|
||||
CapitalGainsReport,
|
||||
HoldingCreate,
|
||||
HoldingResponse,
|
||||
InvestmentTxnCreate,
|
||||
|
|
@ -29,7 +32,7 @@ async def get_portfolio(
|
|||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
return await investment_service.get_portfolio(db, current_user.id)
|
||||
return await investment_service.get_portfolio(db, current_user.id, current_user.base_currency)
|
||||
|
||||
|
||||
@router.get("/investments/performance", response_model=PerformanceMetrics)
|
||||
|
|
@ -37,7 +40,15 @@ async def get_performance(
|
|||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
return await investment_service.get_performance(db, current_user.id)
|
||||
return await investment_service.get_performance(db, current_user.id, current_user.base_currency)
|
||||
|
||||
|
||||
@router.get("/investments/capital-gains", response_model=CapitalGainsReport)
|
||||
async def get_capital_gains(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
return await investment_service.get_capital_gains(db, current_user.id, current_user.base_currency)
|
||||
|
||||
|
||||
# ── Holdings ───────────────────────────────────────────────────────────────
|
||||
|
|
@ -61,6 +72,27 @@ async def create_holding(
|
|||
return investment_service._holding_to_response(holding, asset)
|
||||
|
||||
|
||||
@router.patch("/investments/holdings/{holding_id}", response_model=HoldingResponse)
|
||||
async def update_holding(
|
||||
holding_id: uuid.UUID,
|
||||
quantity: Decimal = Body(...),
|
||||
avg_cost_basis: Decimal = Body(...),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
holding = await investment_service.get_holding(db, current_user.id, holding_id)
|
||||
if not holding:
|
||||
raise HTTPException(status_code=404, detail="Holding not found")
|
||||
holding.quantity = quantity
|
||||
holding.avg_cost_basis = avg_cost_basis
|
||||
await db.commit()
|
||||
await db.refresh(holding)
|
||||
from app.db.models.asset import Asset
|
||||
result = await db.execute(select(Asset).where(Asset.id == holding.asset_id))
|
||||
asset = result.scalar_one_or_none()
|
||||
return investment_service._holding_to_response(holding, asset)
|
||||
|
||||
|
||||
@router.delete("/investments/holdings/{holding_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_holding(
|
||||
holding_id: uuid.UUID,
|
||||
|
|
@ -70,6 +102,8 @@ async def delete_holding(
|
|||
holding = await investment_service.get_holding(db, current_user.id, holding_id)
|
||||
if not holding:
|
||||
raise HTTPException(status_code=404, detail="Holding not found")
|
||||
from app.db.models.investment_transaction import InvestmentTransaction
|
||||
await db.execute(sa_delete(InvestmentTransaction).where(InvestmentTransaction.holding_id == holding_id))
|
||||
await db.delete(holding)
|
||||
await db.commit()
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from app.schemas.report import (
|
|||
CategoryBreakdownReport,
|
||||
IncomeExpenseReport,
|
||||
NetWorthReport,
|
||||
SavingsRateReport,
|
||||
SpendingTrendsReport,
|
||||
)
|
||||
from app.services import report_service
|
||||
|
|
@ -83,6 +84,15 @@ async def spending_trends(
|
|||
return await report_service.get_spending_trends(db, current_user.id, months)
|
||||
|
||||
|
||||
@router.get("/savings-rate", response_model=SavingsRateReport)
|
||||
async def savings_rate(
|
||||
months: int = Query(default=12, ge=1, le=60),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
return await report_service.get_savings_rate_report(db, current_user.id, months)
|
||||
|
||||
|
||||
@router.get("/balance-sheet", response_model=BalanceSheetReport)
|
||||
async def balance_sheet(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ async def get_transactions(
|
|||
date_from: str | None = None,
|
||||
date_to: str | None = None,
|
||||
search: str | None = None,
|
||||
is_recurring: bool | None = None,
|
||||
page: int = Query(default=1, ge=1),
|
||||
page_size: int = Query(default=50, ge=1, le=200),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
|
|
@ -62,6 +63,7 @@ async def get_transactions(
|
|||
date_from=date.fromisoformat(date_from) if date_from else None,
|
||||
date_to=date.fromisoformat(date_to) if date_to else None,
|
||||
search=search,
|
||||
is_recurring=is_recurring,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ from fastapi import Request, Response
|
|||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from app.config import get_settings
|
||||
|
||||
SAFE_METHODS = {"GET", "HEAD", "OPTIONS"}
|
||||
|
||||
SECURITY_HEADERS = {
|
||||
|
|
@ -55,7 +57,7 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
|||
"csrf_token", token,
|
||||
httponly=False, # must be readable by JS
|
||||
samesite="lax",
|
||||
secure=False, # set True if TLS is terminated at this service
|
||||
secure=not get_settings().is_development,
|
||||
)
|
||||
return response
|
||||
|
||||
|
|
@ -63,7 +65,7 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
|||
response = await call_next(request)
|
||||
if not existing_csrf:
|
||||
token = str(uuid.uuid4())
|
||||
response.set_cookie("csrf_token", token, httponly=False, samesite="lax", secure=False)
|
||||
response.set_cookie("csrf_token", token, httponly=False, samesite="lax", secure=not get_settings().is_development)
|
||||
return response
|
||||
|
||||
if request.url.path in {"/api/v1/auth/login", "/api/v1/auth/login/totp"}:
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ class Session(Base):
|
|||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
token_hash: Mapped[str] = mapped_column(Text, unique=True, nullable=False, index=True)
|
||||
refresh_token_hash: Mapped[str | None] = mapped_column(Text, nullable=True, index=True)
|
||||
ip_address: Mapped[str | None] = mapped_column(INET, nullable=True)
|
||||
user_agent: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
last_active_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
|
|
|
|||
|
|
@ -52,6 +52,11 @@ class TOTPVerifyRequest(BaseModel):
|
|||
code: str
|
||||
|
||||
|
||||
class TOTPEnableRequest(BaseModel):
|
||||
secret: str
|
||||
code: str
|
||||
|
||||
|
||||
class SessionInfo(BaseModel):
|
||||
id: uuid.UUID
|
||||
ip_address: str | None
|
||||
|
|
|
|||
|
|
@ -101,3 +101,28 @@ class PerformanceMetrics(BaseModel):
|
|||
total_return: Decimal
|
||||
total_return_pct: Decimal
|
||||
currency: str
|
||||
|
||||
|
||||
class CapitalGainsDisposal(BaseModel):
|
||||
date: DateType
|
||||
symbol: str
|
||||
asset_name: str
|
||||
quantity: Decimal
|
||||
proceeds: Decimal
|
||||
cost: Decimal
|
||||
gain: Decimal
|
||||
currency: str
|
||||
|
||||
|
||||
class TaxYearSummary(BaseModel):
|
||||
tax_year: str
|
||||
disposals: list[CapitalGainsDisposal]
|
||||
total_proceeds: Decimal
|
||||
total_cost: Decimal
|
||||
total_gain: Decimal
|
||||
currency: str
|
||||
|
||||
|
||||
class CapitalGainsReport(BaseModel):
|
||||
tax_years: list[TaxYearSummary]
|
||||
currency: str
|
||||
|
|
|
|||
|
|
@ -96,6 +96,20 @@ class SpendingTrendsReport(BaseModel):
|
|||
currency: str
|
||||
|
||||
|
||||
class SavingsRatePoint(BaseModel):
|
||||
month: str
|
||||
income: Decimal
|
||||
expenses: Decimal
|
||||
savings: Decimal
|
||||
savings_rate: Decimal
|
||||
|
||||
|
||||
class SavingsRateReport(BaseModel):
|
||||
points: list[SavingsRatePoint]
|
||||
avg_savings_rate: Decimal
|
||||
currency: str
|
||||
|
||||
|
||||
class BalanceSheetAccount(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ class TransactionFilter(BaseModel):
|
|||
max_amount: Decimal | None = None
|
||||
search: str | None = None
|
||||
tags: list[str] = []
|
||||
is_recurring: bool | None = None
|
||||
page: int = Field(default=1, ge=1)
|
||||
page_size: int = Field(default=50, ge=1, le=200)
|
||||
|
||||
|
|
|
|||
|
|
@ -137,6 +137,7 @@ async def _create_session(
|
|||
session = Session(
|
||||
user_id=user.id,
|
||||
token_hash=hash_token(access_token),
|
||||
refresh_token_hash=hash_token(refresh_token),
|
||||
ip_address=ip,
|
||||
user_agent=user_agent,
|
||||
last_active_at=now,
|
||||
|
|
|
|||
|
|
@ -10,11 +10,14 @@ from app.db.models.asset_price import AssetPrice
|
|||
from app.db.models.investment_holding import InvestmentHolding
|
||||
from app.db.models.investment_transaction import InvestmentTransaction
|
||||
from app.schemas.investment import (
|
||||
CapitalGainsDisposal,
|
||||
CapitalGainsReport,
|
||||
HoldingCreate,
|
||||
HoldingResponse,
|
||||
InvestmentTxnCreate,
|
||||
PerformanceMetrics,
|
||||
PortfolioSummary,
|
||||
TaxYearSummary,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -23,10 +26,37 @@ async def _get_asset(db: AsyncSession, asset_id: uuid.UUID) -> Asset | None:
|
|||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
def _holding_to_response(holding: InvestmentHolding, asset: Asset) -> HoldingResponse:
|
||||
async def _fetch_fx_rate(db: AsyncSession, from_currency: str, to_currency: str) -> Decimal:
|
||||
if from_currency == to_currency:
|
||||
return Decimal("1")
|
||||
from app.db.models.currency import ExchangeRate
|
||||
result = await db.execute(
|
||||
select(ExchangeRate)
|
||||
.where(ExchangeRate.base_currency == from_currency, ExchangeRate.quote_currency == to_currency)
|
||||
.order_by(ExchangeRate.fetched_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
er = result.scalar_one_or_none()
|
||||
return er.rate if er else Decimal("1")
|
||||
|
||||
|
||||
def _holding_to_response(
|
||||
holding: InvestmentHolding,
|
||||
asset: Asset,
|
||||
fx_rates: dict[tuple[str, str], Decimal] | None = None,
|
||||
) -> HoldingResponse:
|
||||
fx_rates = fx_rates or {}
|
||||
cost_basis_total = holding.quantity * holding.avg_cost_basis
|
||||
current_price = asset.last_price
|
||||
current_value = holding.quantity * current_price if current_price else None
|
||||
|
||||
# Convert asset's last_price to the holding's currency so P&L is comparable
|
||||
current_price_native = asset.last_price
|
||||
if current_price_native is not None and asset.currency != holding.currency:
|
||||
rate = fx_rates.get((asset.currency, holding.currency), Decimal("1"))
|
||||
current_price = current_price_native * rate
|
||||
else:
|
||||
current_price = current_price_native
|
||||
|
||||
current_value = holding.quantity * current_price if current_price is not None else None
|
||||
unrealised_gain = (current_value - cost_basis_total) if current_value is not None else None
|
||||
unrealised_gain_pct = None
|
||||
if unrealised_gain is not None and cost_basis_total > 0:
|
||||
|
|
@ -51,7 +81,7 @@ def _holding_to_response(holding: InvestmentHolding, asset: Asset) -> HoldingRes
|
|||
)
|
||||
|
||||
|
||||
async def get_portfolio(db: AsyncSession, user_id: uuid.UUID) -> PortfolioSummary:
|
||||
async def get_portfolio(db: AsyncSession, user_id: uuid.UUID, base_currency: str = "GBP") -> PortfolioSummary:
|
||||
result = await db.execute(
|
||||
select(InvestmentHolding).where(
|
||||
InvestmentHolding.user_id == user_id,
|
||||
|
|
@ -60,19 +90,44 @@ async def get_portfolio(db: AsyncSession, user_id: uuid.UUID) -> PortfolioSummar
|
|||
)
|
||||
holdings = result.scalars().all()
|
||||
|
||||
# Pre-fetch all assets and determine which FX pairs we need
|
||||
assets: dict[uuid.UUID, Asset] = {}
|
||||
for h in holdings:
|
||||
if h.asset_id not in assets:
|
||||
asset = await _get_asset(db, h.asset_id)
|
||||
if asset:
|
||||
assets[h.asset_id] = asset
|
||||
|
||||
pairs_needed: set[tuple[str, str]] = set()
|
||||
for h in holdings:
|
||||
asset = assets.get(h.asset_id)
|
||||
if not asset:
|
||||
continue
|
||||
if asset.currency != h.currency:
|
||||
pairs_needed.add((asset.currency, h.currency))
|
||||
if h.currency != base_currency:
|
||||
pairs_needed.add((h.currency, base_currency))
|
||||
|
||||
fx_rates: dict[tuple[str, str], Decimal] = {}
|
||||
for from_curr, to_curr in pairs_needed:
|
||||
fx_rates[(from_curr, to_curr)] = await _fetch_fx_rate(db, from_curr, to_curr)
|
||||
|
||||
responses = []
|
||||
total_value = Decimal("0")
|
||||
total_cost = Decimal("0")
|
||||
|
||||
for h in holdings:
|
||||
asset = await _get_asset(db, h.asset_id)
|
||||
asset = assets.get(h.asset_id)
|
||||
if not asset:
|
||||
continue
|
||||
r = _holding_to_response(h, asset)
|
||||
r = _holding_to_response(h, asset, fx_rates)
|
||||
responses.append(r)
|
||||
total_cost += r.cost_basis_total
|
||||
|
||||
# Convert each holding to base_currency for the portfolio totals
|
||||
to_base = fx_rates.get((h.currency, base_currency), Decimal("1")) if h.currency != base_currency else Decimal("1")
|
||||
total_cost += r.cost_basis_total * to_base
|
||||
if r.current_value is not None:
|
||||
total_value += r.current_value
|
||||
total_value += r.current_value * to_base
|
||||
|
||||
total_gain = total_value - total_cost
|
||||
total_gain_pct = (total_gain / total_cost * 100).quantize(Decimal("0.01")) if total_cost > 0 else Decimal("0")
|
||||
|
|
@ -82,7 +137,7 @@ async def get_portfolio(db: AsyncSession, user_id: uuid.UUID) -> PortfolioSummar
|
|||
total_cost=total_cost,
|
||||
total_gain=total_gain,
|
||||
total_gain_pct=total_gain_pct,
|
||||
currency="GBP",
|
||||
currency=base_currency,
|
||||
holdings=responses,
|
||||
)
|
||||
|
||||
|
|
@ -189,18 +244,131 @@ async def list_investment_transactions(
|
|||
return list(result.scalars().all())
|
||||
|
||||
|
||||
async def get_performance(db: AsyncSession, user_id: uuid.UUID) -> PerformanceMetrics:
|
||||
portfolio = await get_portfolio(db, user_id)
|
||||
async def get_performance(db: AsyncSession, user_id: uuid.UUID, base_currency: str = "GBP") -> PerformanceMetrics:
|
||||
portfolio = await get_portfolio(db, user_id, base_currency)
|
||||
total_return = portfolio.total_gain
|
||||
total_return_pct = portfolio.total_gain_pct
|
||||
return PerformanceMetrics(
|
||||
twrr=None, # full TWRR requires snapshot history — placeholder
|
||||
total_return=total_return,
|
||||
total_return_pct=total_return_pct,
|
||||
currency="GBP",
|
||||
currency=base_currency,
|
||||
)
|
||||
|
||||
|
||||
def _uk_tax_year(d: date) -> str:
|
||||
"""Return the UK tax year string for a given date (e.g. '2024/25')."""
|
||||
if d >= date(d.year, 4, 6):
|
||||
return f"{d.year}/{str(d.year + 1)[2:]}"
|
||||
return f"{d.year - 1}/{str(d.year)[2:]}"
|
||||
|
||||
|
||||
async def get_capital_gains(
|
||||
db: AsyncSession, user_id: uuid.UUID, base_currency: str = "GBP"
|
||||
) -> CapitalGainsReport:
|
||||
"""
|
||||
Compute capital gains using the UK Section 104 pool method.
|
||||
Each asset's transactions are replayed chronologically; on each sell
|
||||
the cost of disposal is (sold_qty / pool_qty) * pool_cost.
|
||||
All values are converted to base_currency using current FX rates.
|
||||
"""
|
||||
holdings_result = await db.execute(
|
||||
select(InvestmentHolding).where(InvestmentHolding.user_id == user_id)
|
||||
)
|
||||
holdings = holdings_result.scalars().all()
|
||||
|
||||
# Pre-fetch assets and FX rates
|
||||
assets: dict[uuid.UUID, Asset] = {}
|
||||
holding_currencies: set[str] = set()
|
||||
for h in holdings:
|
||||
if h.asset_id not in assets:
|
||||
a = await _get_asset(db, h.asset_id)
|
||||
if a:
|
||||
assets[h.asset_id] = a
|
||||
holding_currencies.add(h.currency)
|
||||
|
||||
fx_rates: dict[tuple[str, str], Decimal] = {}
|
||||
for curr in holding_currencies:
|
||||
if curr != base_currency:
|
||||
fx_rates[(curr, base_currency)] = await _fetch_fx_rate(db, curr, base_currency)
|
||||
|
||||
disposals_by_year: dict[str, list[CapitalGainsDisposal]] = {}
|
||||
|
||||
for h in holdings:
|
||||
asset = assets.get(h.asset_id)
|
||||
if not asset:
|
||||
continue
|
||||
|
||||
txns_result = await db.execute(
|
||||
select(InvestmentTransaction)
|
||||
.where(InvestmentTransaction.holding_id == h.id)
|
||||
.order_by(InvestmentTransaction.date.asc(), InvestmentTransaction.created_at.asc())
|
||||
)
|
||||
txns = txns_result.scalars().all()
|
||||
|
||||
pool_qty = Decimal("0")
|
||||
pool_cost = Decimal("0") # in holding.currency
|
||||
|
||||
for txn in txns:
|
||||
if txn.type in ("buy", "transfer_in"):
|
||||
cost_of_purchase = txn.quantity * txn.price + txn.fees
|
||||
pool_qty += txn.quantity
|
||||
pool_cost += cost_of_purchase
|
||||
|
||||
elif txn.type in ("sell", "transfer_out") and pool_qty > 0:
|
||||
sell_qty = min(txn.quantity, pool_qty)
|
||||
cost_per_unit = pool_cost / pool_qty
|
||||
cost_of_disposal = cost_per_unit * sell_qty
|
||||
proceeds = txn.price * sell_qty - txn.fees
|
||||
|
||||
# Convert to base_currency
|
||||
to_base = fx_rates.get((h.currency, base_currency), Decimal("1")) if h.currency != base_currency else Decimal("1")
|
||||
proceeds_base = (proceeds * to_base).quantize(Decimal("0.01"))
|
||||
cost_base = (cost_of_disposal * to_base).quantize(Decimal("0.01"))
|
||||
gain_base = proceeds_base - cost_base
|
||||
|
||||
tax_year = _uk_tax_year(txn.date)
|
||||
disposals_by_year.setdefault(tax_year, []).append(
|
||||
CapitalGainsDisposal(
|
||||
date=txn.date,
|
||||
symbol=asset.symbol,
|
||||
asset_name=asset.name,
|
||||
quantity=sell_qty,
|
||||
proceeds=proceeds_base,
|
||||
cost=cost_base,
|
||||
gain=gain_base,
|
||||
currency=base_currency,
|
||||
)
|
||||
)
|
||||
|
||||
pool_qty -= sell_qty
|
||||
pool_cost -= cost_of_disposal
|
||||
if pool_qty <= 0:
|
||||
pool_qty = Decimal("0")
|
||||
pool_cost = Decimal("0")
|
||||
|
||||
elif txn.type == "split" and txn.price > 0:
|
||||
pool_qty = pool_qty * txn.quantity
|
||||
# pool_cost stays the same; avg cost per unit changes
|
||||
|
||||
tax_years: list[TaxYearSummary] = []
|
||||
for year_label in sorted(disposals_by_year.keys(), reverse=True):
|
||||
year_disposals = sorted(disposals_by_year[year_label], key=lambda d: d.date)
|
||||
total_proceeds = sum(d.proceeds for d in year_disposals)
|
||||
total_cost = sum(d.cost for d in year_disposals)
|
||||
total_gain = total_proceeds - total_cost
|
||||
tax_years.append(TaxYearSummary(
|
||||
tax_year=year_label,
|
||||
disposals=year_disposals,
|
||||
total_proceeds=total_proceeds,
|
||||
total_cost=total_cost,
|
||||
total_gain=total_gain,
|
||||
currency=base_currency,
|
||||
))
|
||||
|
||||
return CapitalGainsReport(tax_years=tax_years, currency=base_currency)
|
||||
|
||||
|
||||
async def get_or_create_asset(
|
||||
db: AsyncSession, symbol: str, name: str, asset_type: str,
|
||||
currency: str, data_source: str, data_source_id: str | None,
|
||||
|
|
|
|||
|
|
@ -26,6 +26,8 @@ from app.schemas.report import (
|
|||
IncomeExpenseReport,
|
||||
NetWorthPoint,
|
||||
NetWorthReport,
|
||||
SavingsRatePoint,
|
||||
SavingsRateReport,
|
||||
SpendingTrendPoint,
|
||||
SpendingTrendsReport,
|
||||
)
|
||||
|
|
@ -402,6 +404,50 @@ async def get_balance_sheet(
|
|||
)
|
||||
|
||||
|
||||
async def get_savings_rate_report(
|
||||
db: AsyncSession, user_id: uuid.UUID, months: int = 12
|
||||
) -> SavingsRateReport:
|
||||
cutoff = date.today().replace(day=1) - relativedelta(months=months - 1)
|
||||
result = await db.execute(
|
||||
text("""
|
||||
SELECT
|
||||
TO_CHAR(date, 'YYYY-MM') AS month,
|
||||
SUM(CASE WHEN type = 'income' THEN amount ELSE 0 END) AS income,
|
||||
SUM(CASE WHEN type = 'expense' THEN ABS(amount) ELSE 0 END) AS expenses
|
||||
FROM transactions
|
||||
WHERE user_id = CAST(:uid AS uuid)
|
||||
AND status != 'void'
|
||||
AND deleted_at IS NULL
|
||||
AND date >= :cutoff
|
||||
GROUP BY TO_CHAR(date, 'YYYY-MM')
|
||||
ORDER BY month ASC
|
||||
""").bindparams(uid=str(user_id), cutoff=cutoff)
|
||||
)
|
||||
rows = result.fetchall()
|
||||
|
||||
points = []
|
||||
for row in rows:
|
||||
inc = Decimal(str(row.income or 0))
|
||||
exp = Decimal(str(row.expenses or 0))
|
||||
savings = inc - exp
|
||||
rate = (savings / inc * 100).quantize(Decimal("0.01")) if inc > 0 else Decimal("0")
|
||||
points.append(SavingsRatePoint(
|
||||
month=row.month,
|
||||
income=inc,
|
||||
expenses=exp,
|
||||
savings=savings,
|
||||
savings_rate=rate,
|
||||
))
|
||||
|
||||
n = len(points) or 1
|
||||
avg_rate = sum(p.savings_rate for p in points) / n
|
||||
return SavingsRateReport(
|
||||
points=points,
|
||||
avg_savings_rate=avg_rate.quantize(Decimal("0.01")),
|
||||
currency="GBP",
|
||||
)
|
||||
|
||||
|
||||
async def take_net_worth_snapshot(db: AsyncSession, user_id: uuid.UUID, base_currency: str) -> None:
|
||||
today = date.today()
|
||||
existing = await db.execute(
|
||||
|
|
|
|||
|
|
@ -144,6 +144,8 @@ async def list_transactions(
|
|||
conditions.append(Transaction.amount >= filters.min_amount)
|
||||
if filters.max_amount is not None:
|
||||
conditions.append(Transaction.amount <= filters.max_amount)
|
||||
if filters.is_recurring is not None:
|
||||
conditions.append(Transaction.is_recurring == filters.is_recurring)
|
||||
|
||||
query = select(Transaction).where(and_(*conditions)).order_by(Transaction.date.desc(), Transaction.created_at.desc())
|
||||
|
||||
|
|
|
|||
25
backend/app/workers/backup.py
Normal file
25
backend/app/workers/backup.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
"""
|
||||
Daily encrypted backup job — runs the backup.sh script inside the container.
|
||||
The script does: pg_dump | gzip | gpg symmetric AES-256 → /app/backups/
|
||||
"""
|
||||
import asyncio
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
async def backup_job() -> None:
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
"bash", "/app/scripts/backup.sh",
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.STDOUT,
|
||||
)
|
||||
stdout, _ = await proc.communicate()
|
||||
output = stdout.decode().strip() if stdout else ""
|
||||
if proc.returncode == 0:
|
||||
logger.info("backup_complete", output=output)
|
||||
else:
|
||||
logger.error("backup_failed", returncode=proc.returncode, output=output)
|
||||
except Exception as exc:
|
||||
logger.error("backup_error", error=str(exc))
|
||||
|
|
@ -15,12 +15,13 @@ async def start_scheduler() -> None:
|
|||
from app.workers.snapshot import snapshot_job
|
||||
from app.workers.price_sync import price_sync_job
|
||||
from app.workers.fx_sync import fx_sync_job
|
||||
from app.workers.backup import backup_job
|
||||
_scheduler = AsyncIOScheduler()
|
||||
|
||||
_scheduler.add_job(snapshot_job, CronTrigger(hour=2, minute=0), id="nw_snapshot")
|
||||
_scheduler.add_job(price_sync_job, CronTrigger(minute="*/15"), id="price_sync")
|
||||
_scheduler.add_job(fx_sync_job, CronTrigger(minute=0), id="fx_sync")
|
||||
# _scheduler.add_job(backup_job, CronTrigger(hour=3), id="backup")
|
||||
_scheduler.add_job(backup_job, CronTrigger(hour=3, minute=0), id="backup")
|
||||
# _scheduler.add_job(ml_retrain_job, CronTrigger(day_of_week="sun", hour=1), id="ml_retrain")
|
||||
|
||||
_scheduler.start()
|
||||
|
|
|
|||
51
backend/scripts/backup.sh
Executable file
51
backend/scripts/backup.sh
Executable file
|
|
@ -0,0 +1,51 @@
|
|||
#!/bin/bash
|
||||
# backup.sh — pg_dump | gzip | gpg encrypt → /backups/
|
||||
# Run inside the backend container:
|
||||
# docker compose exec backend bash /app/scripts/backup.sh
|
||||
# Or from host:
|
||||
# docker compose exec -e BACKUP_PASSPHRASE="$BACKUP_PASSPHRASE" backend bash scripts/backup.sh
|
||||
set -euo pipefail
|
||||
|
||||
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
||||
BACKUP_DIR="${BACKUP_DIR:-/app/backups}"
|
||||
BACKUP_FILE="${BACKUP_DIR}/${TIMESTAMP}.sql.gz.gpg"
|
||||
RETENTION_DAYS="${BACKUP_RETENTION_DAYS:-30}"
|
||||
|
||||
# Require passphrase
|
||||
if [ -z "${BACKUP_PASSPHRASE:-}" ]; then
|
||||
echo "[backup] ERROR: BACKUP_PASSPHRASE is not set" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
mkdir -p "${BACKUP_DIR}"
|
||||
echo "[backup] Starting at ${TIMESTAMP}"
|
||||
|
||||
# GPG needs a writable home dir; appuser has no real home
|
||||
export GNUPGHOME=/tmp/.gnupg
|
||||
mkdir -p "${GNUPGHOME}"
|
||||
chmod 700 "${GNUPGHOME}"
|
||||
|
||||
# pg_dump using the DATABASE_URL but swap asyncpg driver for psycopg2-compatible URL
|
||||
PG_URL="${DATABASE_URL/postgresql+asyncpg/postgresql}"
|
||||
|
||||
pg_dump --clean --if-exists "${PG_URL}" \
|
||||
| gzip \
|
||||
| gpg --batch --yes --no-symkey-cache --pinentry-mode loopback \
|
||||
--symmetric --cipher-algo AES256 \
|
||||
--passphrase "${BACKUP_PASSPHRASE}" \
|
||||
--output "${BACKUP_FILE}"
|
||||
|
||||
SIZE=$(du -sh "${BACKUP_FILE}" | cut -f1)
|
||||
echo "[backup] Written ${SIZE} → ${BACKUP_FILE}"
|
||||
|
||||
# List current backups
|
||||
COUNT=$(find "${BACKUP_DIR}" -name "*.sql.gz.gpg" | wc -l)
|
||||
echo "[backup] ${COUNT} backup(s) on disk"
|
||||
|
||||
# Prune old backups
|
||||
PRUNED=$(find "${BACKUP_DIR}" -name "*.sql.gz.gpg" -mtime "+${RETENTION_DAYS}" -print -delete | wc -l)
|
||||
if [ "${PRUNED}" -gt 0 ]; then
|
||||
echo "[backup] Pruned ${PRUNED} backup(s) older than ${RETENTION_DAYS} days"
|
||||
fi
|
||||
|
||||
echo "[backup] Done"
|
||||
9
backend/scripts/entrypoint.sh
Normal file
9
backend/scripts/entrypoint.sh
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
#!/bin/sh
|
||||
set -e
|
||||
|
||||
# Fix ownership of bind-mounted directories so appuser can write to them.
|
||||
# This runs briefly as root before dropping privileges, which is the only
|
||||
# way to handle host directories that Docker creates as root.
|
||||
chown -R appuser:appuser /app/backups /app/uploads 2>/dev/null || true
|
||||
|
||||
exec gosu appuser "$@"
|
||||
56
backend/scripts/restore.sh
Executable file
56
backend/scripts/restore.sh
Executable file
|
|
@ -0,0 +1,56 @@
|
|||
#!/bin/bash
|
||||
# restore.sh — decrypt and restore a backup
|
||||
# Usage:
|
||||
# docker compose exec backend bash scripts/restore.sh <backup_file>
|
||||
# docker compose exec backend bash scripts/restore.sh --list
|
||||
set -euo pipefail
|
||||
|
||||
BACKUP_DIR="${BACKUP_DIR:-/app/backups}"
|
||||
|
||||
if [ "${1:-}" = "--list" ]; then
|
||||
echo "[restore] Available backups in ${BACKUP_DIR}:"
|
||||
find "${BACKUP_DIR}" -name "*.sql.gz.gpg" -printf " %f (%s bytes)\n" | sort -r
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [ -z "${1:-}" ]; then
|
||||
echo "Usage: restore.sh <backup_file.sql.gz.gpg>"
|
||||
echo " restore.sh --list"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
BACKUP_FILE="$1"
|
||||
|
||||
# Accept bare filename or full path
|
||||
if [ ! -f "${BACKUP_FILE}" ] && [ -f "${BACKUP_DIR}/${BACKUP_FILE}" ]; then
|
||||
BACKUP_FILE="${BACKUP_DIR}/${BACKUP_FILE}"
|
||||
fi
|
||||
|
||||
if [ ! -f "${BACKUP_FILE}" ]; then
|
||||
echo "[restore] ERROR: File not found: ${BACKUP_FILE}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -z "${BACKUP_PASSPHRASE:-}" ]; then
|
||||
echo "[restore] ERROR: BACKUP_PASSPHRASE is not set" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PG_URL="${DATABASE_URL/postgresql+asyncpg/postgresql}"
|
||||
|
||||
echo "[restore] WARNING: This will overwrite the current database."
|
||||
echo "[restore] File: ${BACKUP_FILE}"
|
||||
read -r -p "[restore] Type 'yes' to continue: " CONFIRM
|
||||
if [ "${CONFIRM}" != "yes" ]; then
|
||||
echo "[restore] Aborted"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "[restore] Decrypting and restoring…"
|
||||
gpg --batch --yes --decrypt \
|
||||
--passphrase "${BACKUP_PASSPHRASE}" \
|
||||
"${BACKUP_FILE}" \
|
||||
| gunzip \
|
||||
| psql "${PG_URL}"
|
||||
|
||||
echo "[restore] Done"
|
||||
195
backend/scripts/rotate_keys.py
Normal file
195
backend/scripts/rotate_keys.py
Normal file
|
|
@ -0,0 +1,195 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
AES-256-GCM encryption key rotation.
|
||||
|
||||
Re-encrypts all PII fields with a new key in a single atomic transaction.
|
||||
If anything fails the database is left unchanged.
|
||||
|
||||
Usage (from inside the backend container):
|
||||
python /app/scripts/rotate_keys.py --old-key <64-hex> --new-key <64-hex>
|
||||
|
||||
--dry-run Decrypt and re-encrypt in memory only, do not write to DB.
|
||||
|
||||
After the script reports success:
|
||||
1. Update ENCRYPTION_KEY in your .env file to the new key.
|
||||
2. Restart the backend: docker compose up -d backend
|
||||
"""
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from os import urandom
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Crypto helpers (standalone, no dependency on app code)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_decrypt(key_hex: str):
|
||||
key = bytes.fromhex(key_hex)
|
||||
if len(key) != 32:
|
||||
sys.exit("ERROR: key must be 32 bytes (64 hex characters)")
|
||||
aesgcm = AESGCM(key)
|
||||
|
||||
def decrypt(data: bytes) -> str:
|
||||
if not data:
|
||||
return ""
|
||||
iv, ct = data[:12], data[12:]
|
||||
return aesgcm.decrypt(iv, ct, None).decode()
|
||||
|
||||
return decrypt
|
||||
|
||||
|
||||
def _make_encrypt(key_hex: str):
|
||||
key = bytes.fromhex(key_hex)
|
||||
if len(key) != 32:
|
||||
sys.exit("ERROR: key must be 32 bytes (64 hex characters)")
|
||||
aesgcm = AESGCM(key)
|
||||
|
||||
def encrypt(plaintext: str) -> bytes:
|
||||
if not plaintext:
|
||||
return b""
|
||||
iv = urandom(12)
|
||||
return iv + aesgcm.encrypt(iv, plaintext.encode(), None)
|
||||
|
||||
return encrypt
|
||||
|
||||
|
||||
def _rotate_bytes(data: bytes | None, decrypt, encrypt) -> bytes | None:
|
||||
"""Decrypt with old key, re-encrypt with new key. None/empty passes through."""
|
||||
if not data:
|
||||
return data
|
||||
return encrypt(decrypt(data))
|
||||
|
||||
|
||||
def _rotate_hex(hex_str: str | None, decrypt, encrypt) -> str | None:
|
||||
"""Same as _rotate_bytes but field is stored as hex text (TOTP secret)."""
|
||||
if not hex_str:
|
||||
return hex_str
|
||||
return encrypt(decrypt(bytes.fromhex(hex_str))).hex()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main rotation logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def rotate(old_key: str, new_key: str, dry_run: bool) -> None:
|
||||
import asyncpg
|
||||
|
||||
decrypt = _make_decrypt(old_key)
|
||||
encrypt = _make_encrypt(new_key)
|
||||
|
||||
db_url = os.environ.get("DATABASE_URL", "")
|
||||
if not db_url:
|
||||
sys.exit("ERROR: DATABASE_URL environment variable not set")
|
||||
pg_url = db_url.replace("postgresql+asyncpg://", "postgresql://")
|
||||
|
||||
print(f"Connecting to database...")
|
||||
conn = await asyncpg.connect(pg_url)
|
||||
|
||||
try:
|
||||
async with conn.transaction():
|
||||
# ── accounts ──────────────────────────────────────────────────
|
||||
print("Rotating accounts (name, institution, notes)...")
|
||||
rows = await conn.fetch(
|
||||
"SELECT id, name, institution, notes FROM accounts"
|
||||
)
|
||||
updated = 0
|
||||
for row in rows:
|
||||
new_name = _rotate_bytes(row["name"], decrypt, encrypt)
|
||||
new_inst = _rotate_bytes(row["institution"], decrypt, encrypt)
|
||||
new_notes = _rotate_bytes(row["notes"], decrypt, encrypt)
|
||||
if not dry_run:
|
||||
await conn.execute(
|
||||
"UPDATE accounts SET name=$1, institution=$2, notes=$3 WHERE id=$4",
|
||||
new_name, new_inst, new_notes, row["id"],
|
||||
)
|
||||
updated += 1
|
||||
print(f" {'(dry-run) would rotate' if dry_run else 'rotated'} {updated} accounts")
|
||||
|
||||
# ── transactions ──────────────────────────────────────────────
|
||||
print("Rotating transactions (description, merchant, notes)...")
|
||||
rows = await conn.fetch(
|
||||
"SELECT id, description, merchant, notes FROM transactions"
|
||||
)
|
||||
updated = 0
|
||||
for row in rows:
|
||||
new_desc = _rotate_bytes(row["description"], decrypt, encrypt)
|
||||
new_merch = _rotate_bytes(row["merchant"], decrypt, encrypt)
|
||||
new_notes = _rotate_bytes(row["notes"], decrypt, encrypt)
|
||||
if not dry_run:
|
||||
await conn.execute(
|
||||
"UPDATE transactions SET description=$1, merchant=$2, notes=$3 WHERE id=$4",
|
||||
new_desc, new_merch, new_notes, row["id"],
|
||||
)
|
||||
updated += 1
|
||||
print(f" {'(dry-run) would rotate' if dry_run else 'rotated'} {updated} transactions")
|
||||
|
||||
# ── investment transactions ────────────────────────────────────
|
||||
print("Rotating investment transaction notes...")
|
||||
rows = await conn.fetch(
|
||||
"SELECT id, notes FROM investment_transactions WHERE notes IS NOT NULL"
|
||||
)
|
||||
updated = 0
|
||||
for row in rows:
|
||||
new_notes = _rotate_bytes(row["notes"], decrypt, encrypt)
|
||||
if not dry_run:
|
||||
await conn.execute(
|
||||
"UPDATE investment_transactions SET notes=$1 WHERE id=$2",
|
||||
new_notes, row["id"],
|
||||
)
|
||||
updated += 1
|
||||
print(f" {'(dry-run) would rotate' if dry_run else 'rotated'} {updated} investment transaction notes")
|
||||
|
||||
# ── users — TOTP secret (hex-encoded) ─────────────────────────
|
||||
print("Rotating user TOTP secrets...")
|
||||
rows = await conn.fetch(
|
||||
"SELECT id, totp_secret FROM users WHERE totp_secret IS NOT NULL"
|
||||
)
|
||||
updated = 0
|
||||
for row in rows:
|
||||
new_secret = _rotate_hex(row["totp_secret"], decrypt, encrypt)
|
||||
if not dry_run:
|
||||
await conn.execute(
|
||||
"UPDATE users SET totp_secret=$1 WHERE id=$2",
|
||||
new_secret, row["id"],
|
||||
)
|
||||
updated += 1
|
||||
print(f" {'(dry-run) would rotate' if dry_run else 'rotated'} {updated} TOTP secrets")
|
||||
|
||||
if dry_run:
|
||||
raise _DryRunAbort()
|
||||
|
||||
except _DryRunAbort:
|
||||
print("\nDry-run complete — no changes written.")
|
||||
return
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
print("\n✓ Key rotation complete.")
|
||||
print("\nNext steps:")
|
||||
print(" 1. Update ENCRYPTION_KEY in your .env to the new key.")
|
||||
print(" 2. Restart the backend: docker compose up -d backend")
|
||||
|
||||
|
||||
class _DryRunAbort(Exception):
|
||||
"""Raised inside the transaction to trigger an asyncpg rollback in dry-run mode."""
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Rotate AES-256-GCM encryption key")
|
||||
parser.add_argument("--old-key", required=True, help="Current key as 64-char hex")
|
||||
parser.add_argument("--new-key", required=True, help="New key as 64-char hex")
|
||||
parser.add_argument("--dry-run", action="store_true", help="Validate only, do not write")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.old_key == args.new_key:
|
||||
sys.exit("ERROR: old key and new key are identical — nothing to do")
|
||||
|
||||
print(f"{'DRY RUN — ' if args.dry_run else ''}AES key rotation starting...")
|
||||
asyncio.run(rotate(args.old_key, args.new_key, args.dry_run))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
54
backend/scripts/rotate_keys.sh
Executable file
54
backend/scripts/rotate_keys.sh
Executable file
|
|
@ -0,0 +1,54 @@
|
|||
#!/bin/bash
|
||||
# rotate_keys.sh — re-encrypt all AES-256-GCM fields with a new key.
|
||||
# The application must be STOPPED before running.
|
||||
#
|
||||
# Usage:
|
||||
# NEW_ENCRYPTION_KEY=$(python3 -c "import secrets; print(secrets.token_hex(32))")
|
||||
# docker compose stop backend
|
||||
# NEW_ENCRYPTION_KEY="$NEW_ENCRYPTION_KEY" ./scripts/rotate_keys.sh
|
||||
# # On success, update ENCRYPTION_KEY in .env, then:
|
||||
# docker compose up -d backend
|
||||
set -euo pipefail
|
||||
|
||||
if [ -z "${NEW_ENCRYPTION_KEY:-}" ]; then
|
||||
echo "ERROR: NEW_ENCRYPTION_KEY is not set"
|
||||
echo "Generate with: python3 -c \"import secrets; print(secrets.token_hex(32))\""
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -z "${ENCRYPTION_KEY:-}" ]; then
|
||||
# Try to load from .env in the project root
|
||||
ENV_FILE="$(dirname "$0")/../.env"
|
||||
if [ -f "${ENV_FILE}" ]; then
|
||||
ENCRYPTION_KEY=$(grep -E '^ENCRYPTION_KEY=' "${ENV_FILE}" | cut -d= -f2- | tr -d '"' | tr -d "'")
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ -z "${ENCRYPTION_KEY:-}" ]; then
|
||||
echo "ERROR: ENCRYPTION_KEY (current key) could not be found in environment or .env"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ "${ENCRYPTION_KEY}" = "${NEW_ENCRYPTION_KEY}" ]; then
|
||||
echo "ERROR: NEW_ENCRYPTION_KEY is the same as the current key — nothing to do"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "[rotate] This will re-encrypt ALL sensitive fields in the database."
|
||||
echo "[rotate] Ensure the application containers are stopped before continuing."
|
||||
read -r -p "[rotate] Type 'yes' to continue: " CONFIRM
|
||||
if [ "${CONFIRM}" != "yes" ]; then
|
||||
echo "[rotate] Aborted"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "[rotate] Running key rotation inside the backend container…"
|
||||
docker compose exec \
|
||||
-e ENCRYPTION_KEY="${ENCRYPTION_KEY}" \
|
||||
-e NEW_ENCRYPTION_KEY="${NEW_ENCRYPTION_KEY}" \
|
||||
backend python -m app.core.key_rotation
|
||||
|
||||
echo ""
|
||||
echo "[rotate] SUCCESS. Next steps:"
|
||||
echo " 1. Update ENCRYPTION_KEY in .env to: ${NEW_ENCRYPTION_KEY}"
|
||||
echo " 2. Restart the backend: docker compose up -d backend"
|
||||
Loading…
Add table
Add a link
Reference in a new issue