Add pensions module and integrate with tax report

Adds a full pensions feature: SIPP/workplace DC/LISA account metadata,
contribution recording with relief-at-source/net-pay/salary-sacrifice
gross calculations, state pension tracker, annual allowance monitor,
and LISA summary. Pension contributions feed into the tax report
(RAS gross totals, allowance used). Includes two Alembic migrations,
backend service/schema/API, and full frontend pensions page with
cards for allowance, state pension, LISA, and retirement projection.

Also fixes CSRF cookie secure flag (must be false for HTTP deployments)
and extends tax schemas/service to expose pension data in the report.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
megaproxy 2026-04-28 09:59:01 +00:00
parent b30e8e577b
commit 1a2c8efd01
30 changed files with 3537 additions and 8 deletions

View file

@ -0,0 +1,664 @@
from __future__ import annotations
import uuid
from datetime import date, datetime, timezone
from decimal import Decimal
from typing import TYPE_CHECKING
from app.db.models.account import Account
from sqlalchemy import func, select
from app.core.security import decrypt_field, encrypt_field
from app.db.models.pension import PensionContribution, PensionMetadata, StatePensionRecord
from app.schemas.pension import (
AllowanceSummary,
CarryForwardYear,
ChartDataPoint,
LisaSummary,
LisaTaxYearBreakdown,
PensionContributionCreate,
PensionContributionUpdate,
PensionMetadataCreate,
PensionMetadataUpdate,
ProjectionScenario,
RetirementProjection,
StatePensionCreate,
YtdSummary,
_FULL_SP_WEEKLY,
_SP_AGE,
)
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
def _uk_tax_year(d: date) -> int:
"""Return the tax year ending on 5 Apr that contains date d.
E.g. 2026-01-01 2026; 2025-04-05 2025; 2025-04-06 2026."""
return d.year + 1 if (d.month > 4 or (d.month == 4 and d.day >= 6)) else d.year
def _compute_gross_and_relief(
member_amount: Decimal,
relief_type: str,
) -> tuple[Decimal, Decimal]:
"""Return (gross_amount, relief_amount) given a net member contribution and relief type."""
if relief_type == "relief_at_source":
gross = member_amount / Decimal("0.8")
relief = gross - member_amount
elif relief_type in ("net_pay", "salary_sacrifice"):
gross = member_amount
relief = Decimal("0")
else: # none
gross = member_amount
relief = Decimal("0")
return gross.quantize(Decimal("0.01")), relief.quantize(Decimal("0.01"))
def _enc(v: str | None) -> bytes | None:
return encrypt_field(v) if v else None
def _dec(v: bytes | None) -> str | None:
return decrypt_field(v) if v else None
# ---------------------------------------------------------------------------
# Pension metadata (one per pension account)
# ---------------------------------------------------------------------------
async def get_pension_metadata(
db: "AsyncSession",
account_id: uuid.UUID,
user_id: uuid.UUID,
) -> PensionMetadata | None:
result = await db.execute(
select(PensionMetadata).where(
PensionMetadata.account_id == account_id,
PensionMetadata.user_id == user_id,
)
)
return result.scalar_one_or_none()
async def upsert_pension_metadata(
db: "AsyncSession",
account_id: uuid.UUID,
user_id: uuid.UUID,
data: PensionMetadataCreate,
) -> PensionMetadata:
existing = await get_pension_metadata(db, account_id, user_id)
now = datetime.now(timezone.utc)
if existing:
existing.pension_type = data.pension_type
existing.provider_name_enc = _enc(data.provider_name)
existing.scheme_name_enc = _enc(data.scheme_name)
existing.member_reference_enc = _enc(data.member_reference)
existing.dob = data.dob
existing.target_retirement_age = data.target_retirement_age
existing.assumed_growth_rate = data.assumed_growth_rate
existing.updated_at = now
return existing
meta = PensionMetadata(
id=uuid.uuid4(),
user_id=user_id,
account_id=account_id,
pension_type=data.pension_type,
provider_name_enc=_enc(data.provider_name),
scheme_name_enc=_enc(data.scheme_name),
member_reference_enc=_enc(data.member_reference),
dob=data.dob,
target_retirement_age=data.target_retirement_age,
assumed_growth_rate=data.assumed_growth_rate,
created_at=now,
updated_at=now,
)
db.add(meta)
return meta
async def update_pension_metadata(
db: "AsyncSession",
account_id: uuid.UUID,
user_id: uuid.UUID,
data: PensionMetadataUpdate,
) -> PensionMetadata | None:
meta = await get_pension_metadata(db, account_id, user_id)
if not meta:
return None
now = datetime.now(timezone.utc)
if data.pension_type is not None:
meta.pension_type = data.pension_type
if data.provider_name is not None:
meta.provider_name_enc = _enc(data.provider_name)
if data.scheme_name is not None:
meta.scheme_name_enc = _enc(data.scheme_name)
if data.member_reference is not None:
meta.member_reference_enc = _enc(data.member_reference)
if data.dob is not None:
meta.dob = data.dob
if data.target_retirement_age is not None:
meta.target_retirement_age = data.target_retirement_age
if data.assumed_growth_rate is not None:
meta.assumed_growth_rate = data.assumed_growth_rate
meta.updated_at = now
return meta
def decode_metadata(meta: PensionMetadata) -> dict:
return {
"id": meta.id,
"account_id": meta.account_id,
"pension_type": meta.pension_type,
"provider_name": _dec(meta.provider_name_enc),
"scheme_name": _dec(meta.scheme_name_enc),
"member_reference": _dec(meta.member_reference_enc),
"dob": meta.dob,
"target_retirement_age": meta.target_retirement_age,
"assumed_growth_rate": meta.assumed_growth_rate,
"created_at": meta.created_at,
"updated_at": meta.updated_at,
}
# ---------------------------------------------------------------------------
# Contributions
# ---------------------------------------------------------------------------
async def list_contributions(
db: "AsyncSession",
pension_id: uuid.UUID,
user_id: uuid.UUID,
tax_year: int | None = None,
) -> list[dict]:
q = select(PensionContribution).where(
PensionContribution.pension_id == pension_id,
PensionContribution.user_id == user_id,
)
if tax_year is not None:
q = q.where(PensionContribution.tax_year == tax_year)
q = q.order_by(PensionContribution.contribution_date.desc())
result = await db.execute(q)
rows = result.scalars().all()
return [_decode_contribution(r) for r in rows]
async def add_contribution(
db: "AsyncSession",
pension_id: uuid.UUID,
user_id: uuid.UUID,
data: PensionContributionCreate,
) -> dict:
gross, relief = _compute_gross_and_relief(data.member_amount, data.relief_type)
tax_year = _uk_tax_year(data.contribution_date)
contrib = PensionContribution(
id=uuid.uuid4(),
user_id=user_id,
pension_id=pension_id,
contribution_date=data.contribution_date,
tax_year=tax_year,
member_amount=data.member_amount,
employer_amount=data.employer_amount,
relief_type=data.relief_type,
gross_amount=gross,
relief_amount=relief,
notes_enc=_enc(data.notes),
created_at=datetime.now(timezone.utc),
)
db.add(contrib)
return _decode_contribution(contrib)
async def update_contribution(
db: "AsyncSession",
contribution_id: uuid.UUID,
user_id: uuid.UUID,
data: PensionContributionUpdate,
) -> dict | None:
result = await db.execute(
select(PensionContribution).where(
PensionContribution.id == contribution_id,
PensionContribution.user_id == user_id,
)
)
contrib = result.scalar_one_or_none()
if not contrib:
return None
if data.contribution_date is not None:
contrib.contribution_date = data.contribution_date
contrib.tax_year = _uk_tax_year(data.contribution_date)
if data.member_amount is not None:
contrib.member_amount = data.member_amount
if data.employer_amount is not None:
contrib.employer_amount = data.employer_amount
if data.relief_type is not None:
contrib.relief_type = data.relief_type
if data.notes is not None:
contrib.notes_enc = _enc(data.notes)
gross, relief = _compute_gross_and_relief(contrib.member_amount, contrib.relief_type)
contrib.gross_amount = gross
contrib.relief_amount = relief
return _decode_contribution(contrib)
async def delete_contribution(
db: "AsyncSession",
contribution_id: uuid.UUID,
user_id: uuid.UUID,
) -> bool:
result = await db.execute(
select(PensionContribution).where(
PensionContribution.id == contribution_id,
PensionContribution.user_id == user_id,
)
)
contrib = result.scalar_one_or_none()
if not contrib:
return False
await db.delete(contrib)
return True
def _decode_contribution(c: PensionContribution) -> dict:
return {
"id": c.id,
"pension_id": c.pension_id,
"contribution_date": c.contribution_date,
"tax_year": c.tax_year,
"member_amount": c.member_amount,
"employer_amount": c.employer_amount,
"relief_type": c.relief_type,
"gross_amount": c.gross_amount,
"relief_amount": c.relief_amount,
"notes": _dec(c.notes_enc),
"created_at": c.created_at,
}
# ---------------------------------------------------------------------------
# YTD / summary
# ---------------------------------------------------------------------------
async def get_ytd_summary(
db: "AsyncSession",
pension_id: uuid.UUID,
user_id: uuid.UUID,
tax_year: int,
) -> YtdSummary:
result = await db.execute(
select(
func.count(PensionContribution.id).label("count"),
func.coalesce(func.sum(PensionContribution.member_amount), 0).label("member_total"),
func.coalesce(func.sum(PensionContribution.employer_amount), 0).label("employer_total"),
func.coalesce(func.sum(PensionContribution.gross_amount), 0).label("gross_total"),
func.coalesce(func.sum(PensionContribution.relief_amount), 0).label("relief_total"),
).where(
PensionContribution.pension_id == pension_id,
PensionContribution.user_id == user_id,
PensionContribution.tax_year == tax_year,
)
)
row = result.one()
return YtdSummary(
tax_year=tax_year,
member_total=Decimal(str(row.member_total)),
employer_total=Decimal(str(row.employer_total)),
gross_total=Decimal(str(row.gross_total)),
relief_total=Decimal(str(row.relief_total)),
contribution_count=row.count,
)
async def get_all_pensions_ytd_summary(
db: "AsyncSession",
user_id: uuid.UUID,
tax_year: int,
) -> YtdSummary:
"""Aggregate YTD totals across all pension accounts for the summary header."""
result = await db.execute(
select(
func.count(PensionContribution.id).label("count"),
func.coalesce(func.sum(PensionContribution.member_amount), 0).label("member_total"),
func.coalesce(func.sum(PensionContribution.employer_amount), 0).label("employer_total"),
func.coalesce(func.sum(PensionContribution.gross_amount), 0).label("gross_total"),
func.coalesce(func.sum(PensionContribution.relief_amount), 0).label("relief_total"),
).where(
PensionContribution.user_id == user_id,
PensionContribution.tax_year == tax_year,
)
)
row = result.one()
return YtdSummary(
tax_year=tax_year,
member_total=Decimal(str(row.member_total)),
employer_total=Decimal(str(row.employer_total)),
gross_total=Decimal(str(row.gross_total)),
relief_total=Decimal(str(row.relief_total)),
contribution_count=row.count,
)
# ---------------------------------------------------------------------------
# Annual allowance
# ---------------------------------------------------------------------------
_STANDARD_ALLOWANCE = Decimal("60000")
async def get_allowance_summary(
db: "AsyncSession",
user_id: uuid.UUID,
tax_year: int,
) -> AllowanceSummary:
"""
Returns annual allowance usage for tax_year plus carry-forward from
the prior 3 tax years (oldest-first, as HMRC requires).
"""
years_to_fetch = [tax_year - 3, tax_year - 2, tax_year - 1, tax_year]
result = await db.execute(
select(
PensionContribution.tax_year,
PensionContribution.relief_type,
func.coalesce(func.sum(PensionContribution.gross_amount), 0).label("gross_total"),
func.coalesce(func.sum(PensionContribution.relief_amount), 0).label("relief_total"),
).where(
PensionContribution.user_id == user_id,
PensionContribution.tax_year.in_(years_to_fetch),
).group_by(
PensionContribution.tax_year,
PensionContribution.relief_type,
)
)
rows = result.all()
year_gross: dict[int, Decimal] = {y: Decimal("0") for y in years_to_fetch}
year_ras_relief: dict[int, Decimal] = {y: Decimal("0") for y in years_to_fetch}
for row in rows:
year_gross[row.tax_year] = year_gross.get(row.tax_year, Decimal("0")) + Decimal(str(row.gross_total))
if row.relief_type == "relief_at_source":
year_ras_relief[row.tax_year] = year_ras_relief.get(row.tax_year, Decimal("0")) + Decimal(str(row.relief_total))
carry_forward: list[CarryForwardYear] = []
carry_forward_total = Decimal("0")
for yr in [tax_year - 3, tax_year - 2, tax_year - 1]:
contribs = year_gross[yr]
unused = max(Decimal("0"), _STANDARD_ALLOWANCE - contribs)
carry_forward.append(CarryForwardYear(
tax_year=yr,
standard_allowance=_STANDARD_ALLOWANCE,
contributions=contribs,
unused=unused,
))
carry_forward_total += unused
contributions_total = year_gross[tax_year]
remaining = max(Decimal("0"), _STANDARD_ALLOWANCE - contributions_total)
total_available = _STANDARD_ALLOWANCE + carry_forward_total
ras_relief = year_ras_relief[tax_year]
ras_gross_current = sum(
(Decimal(str(r.gross_total)) for r in rows if r.tax_year == tax_year and r.relief_type == "relief_at_source"),
Decimal("0"),
)
higher_rate_claimable = (ras_gross_current * Decimal("0.20")).quantize(Decimal("0.01"))
additional_rate_claimable = (ras_gross_current * Decimal("0.05")).quantize(Decimal("0.01"))
return AllowanceSummary(
tax_year=tax_year,
standard_allowance=_STANDARD_ALLOWANCE,
contributions_total=contributions_total,
remaining=remaining,
carry_forward=carry_forward,
carry_forward_total=carry_forward_total,
total_available=total_available,
relief_ras_total=ras_relief,
relief_higher_rate_claimable=higher_rate_claimable,
relief_additional_rate_claimable=additional_rate_claimable,
)
# ---------------------------------------------------------------------------
# State Pension
# ---------------------------------------------------------------------------
def _sp_amounts(qualifying_years: int) -> tuple[Decimal, Decimal]:
weekly = min(_FULL_SP_WEEKLY, (_FULL_SP_WEEKLY * qualifying_years / 35).quantize(Decimal("0.01")))
return weekly, (weekly * 52).quantize(Decimal("0.01"))
async def get_state_pension(
db: "AsyncSession",
user_id: uuid.UUID,
) -> StatePensionRecord | None:
result = await db.execute(
select(StatePensionRecord).where(StatePensionRecord.user_id == user_id)
)
return result.scalar_one_or_none()
async def upsert_state_pension(
db: "AsyncSession",
user_id: uuid.UUID,
data: StatePensionCreate,
) -> dict:
now = datetime.now(timezone.utc)
existing = await get_state_pension(db, user_id)
if existing:
existing.qualifying_years = data.qualifying_years
existing.checked_date = data.checked_date
existing.updated_at = now
record = existing
else:
record = StatePensionRecord(
id=uuid.uuid4(),
user_id=user_id,
qualifying_years=data.qualifying_years,
checked_date=data.checked_date,
created_at=now,
updated_at=now,
)
db.add(record)
return _decode_sp(record)
def _decode_sp(record: StatePensionRecord) -> dict:
weekly, annual = _sp_amounts(record.qualifying_years)
years_to_full = max(0, 35 - record.qualifying_years)
return {
"id": record.id,
"qualifying_years": record.qualifying_years,
"checked_date": record.checked_date,
"weekly_amount": weekly,
"annual_amount": annual,
"is_full_pension": record.qualifying_years >= 35,
"years_to_full": years_to_full,
"state_pension_age": _SP_AGE,
}
# ---------------------------------------------------------------------------
# Retirement projection
# ---------------------------------------------------------------------------
_SCENARIOS = [
("Conservative 2%", Decimal("0.02")),
("Moderate 5%", Decimal("0.05")),
("Growth 8%", Decimal("0.08")),
]
async def get_retirement_projection(
db: "AsyncSession",
user_id: uuid.UUID,
account_id: uuid.UUID,
) -> RetirementProjection:
acc_result = await db.execute(
select(Account).where(Account.id == account_id, Account.user_id == user_id, Account.deleted_at.is_(None))
)
account = acc_result.scalar_one_or_none()
if not account:
raise ValueError("Account not found")
meta = await get_pension_metadata(db, account_id, user_id)
if not meta or not meta.dob or not meta.target_retirement_age:
raise ValueError("Set date of birth and target retirement age in pension details first")
today = date.today()
age_today = (today - meta.dob).days / 365.25
years_to_retirement = max(0, int(meta.target_retirement_age - age_today))
current_balance = Decimal(str(account.current_balance))
# Build year-by-year chart data
chart_data: list[ChartDataPoint] = []
pots = {r: current_balance for _, r in _SCENARIOS}
current_year = today.year
for yr in range(years_to_retirement + 1):
chart_data.append(ChartDataPoint(
year=current_year + yr,
pot_2pct=pots[Decimal("0.02")].quantize(Decimal("1")),
pot_5pct=pots[Decimal("0.05")].quantize(Decimal("1")),
pot_8pct=pots[Decimal("0.08")].quantize(Decimal("1")),
))
for _, rate in _SCENARIOS:
pots[rate] = pots[rate] * (1 + rate)
# Final pot values at retirement (last chart point)
final = chart_data[-1] if chart_data else None
scenario_pots = {
Decimal("0.02"): Decimal(str(final.pot_2pct)) if final else current_balance,
Decimal("0.05"): Decimal(str(final.pot_5pct)) if final else current_balance,
Decimal("0.08"): Decimal(str(final.pot_8pct)) if final else current_balance,
}
scenarios = [
ProjectionScenario(
label=label,
growth_rate=rate,
projected_pot=scenario_pots[rate],
annual_drawdown_4pct=(scenario_pots[rate] * Decimal("0.04")).quantize(Decimal("1")),
annual_drawdown_3pct=(scenario_pots[rate] * Decimal("0.03")).quantize(Decimal("1")),
)
for label, rate in _SCENARIOS
]
# State pension
sp_record = await get_state_pension(db, user_id)
sp_annual = None
if sp_record:
_, sp_annual = _sp_amounts(sp_record.qualifying_years)
from app.core.security import decrypt_field
acc_name = decrypt_field(account.name_enc) if account.name_enc else ""
return RetirementProjection(
account_id=account_id,
account_name=acc_name,
current_balance=current_balance,
years_to_retirement=years_to_retirement,
target_retirement_age=meta.target_retirement_age,
scenarios=scenarios,
state_pension_annual=sp_annual,
state_pension_age=_SP_AGE,
chart_data=chart_data,
)
# ---------------------------------------------------------------------------
# LISA summary
# ---------------------------------------------------------------------------
_LISA_ANNUAL_LIMIT = Decimal("4000")
_LISA_BONUS_RATE = Decimal("0.25")
_LISA_WITHDRAWAL_PENALTY = Decimal("0.25")
async def get_lisa_summary(
db: "AsyncSession",
user_id: uuid.UUID,
account_id: uuid.UUID,
) -> LisaSummary:
# Verify account is LISA type
meta_result = await db.execute(
select(PensionMetadata).where(
PensionMetadata.account_id == account_id,
PensionMetadata.user_id == user_id,
)
)
meta = meta_result.scalar_one_or_none()
if meta is None or meta.pension_type != "lisa":
raise ValueError("Account is not a LISA")
# Load account for balance and created_at
acc_result = await db.execute(
select(Account).where(Account.id == account_id, Account.user_id == user_id)
)
account = acc_result.scalar_one_or_none()
if account is None:
raise ValueError("Account not found")
current_balance = Decimal(str(account.current_balance or 0))
# Load all contributions
contrib_result = await db.execute(
select(PensionContribution).where(
PensionContribution.pension_id == meta.id,
PensionContribution.user_id == user_id,
)
)
contribs = contrib_result.scalars().all()
# Group by tax year
by_year: dict[int, Decimal] = {}
for c in contribs:
by_year[c.tax_year] = by_year.get(c.tax_year, Decimal("0")) + Decimal(str(c.member_amount))
current_year = _uk_tax_year(date.today())
breakdown: list[LisaTaxYearBreakdown] = []
for ty in sorted(by_year):
contributions = by_year[ty].quantize(Decimal("0.01"))
capped = min(contributions, _LISA_ANNUAL_LIMIT)
bonus = (capped * _LISA_BONUS_RATE).quantize(Decimal("0.01"))
limit_remaining = max(Decimal("0"), _LISA_ANNUAL_LIMIT - contributions).quantize(Decimal("0.01"))
limit_used_pct = min(Decimal("100"), (contributions / _LISA_ANNUAL_LIMIT * 100)).quantize(Decimal("0.01"))
breakdown.append(LisaTaxYearBreakdown(
tax_year=ty,
contributions=contributions,
bonus_expected=bonus,
limit_remaining=limit_remaining,
limit_used_pct=limit_used_pct,
))
current_contributions = by_year.get(current_year, Decimal("0")).quantize(Decimal("0.01"))
current_capped = min(current_contributions, _LISA_ANNUAL_LIMIT)
current_bonus = (current_capped * _LISA_BONUS_RATE).quantize(Decimal("0.01"))
current_limit_remaining = max(Decimal("0"), _LISA_ANNUAL_LIMIT - current_contributions).quantize(Decimal("0.01"))
total_contributions = sum(by_year.values(), Decimal("0")).quantize(Decimal("0.01"))
total_capped = sum(min(v, _LISA_ANNUAL_LIMIT) for v in by_year.values())
total_bonus = (total_capped * _LISA_BONUS_RATE).quantize(Decimal("0.01"))
# Withdrawal penalty: 25% of current balance (which includes any bonus already paid)
penalty_amount = (current_balance * _LISA_WITHDRAWAL_PENALTY).quantize(Decimal("0.01"))
penalty_pct = _LISA_WITHDRAWAL_PENALTY * 100
return LisaSummary(
account_id=account_id,
account_name=decrypt_field(account.name_enc) if account.name_enc else "",
tax_year_breakdown=breakdown,
current_year_contributions=current_contributions,
current_year_bonus_expected=current_bonus,
current_year_limit_remaining=current_limit_remaining,
total_contributions=total_contributions,
total_bonus_expected=total_bonus,
account_opened_date=account.created_at,
withdrawal_penalty_amount=penalty_amount,
withdrawal_penalty_pct=penalty_pct,
penalty_warning=True,
)