from __future__ import annotations import uuid from datetime import date, datetime, timezone from decimal import Decimal from typing import TYPE_CHECKING from app.db.models.account import Account from sqlalchemy import func, select from app.core.security import decrypt_field, encrypt_field from app.db.models.pension import PensionContribution, PensionMetadata, StatePensionRecord from app.schemas.pension import ( AllowanceSummary, CarryForwardYear, ChartDataPoint, LisaSummary, LisaTaxYearBreakdown, PensionContributionCreate, PensionContributionUpdate, PensionMetadataCreate, PensionMetadataUpdate, ProjectionScenario, RetirementProjection, StatePensionCreate, YtdSummary, _FULL_SP_WEEKLY, _SP_AGE, ) if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncSession def _uk_tax_year(d: date) -> int: """Return the tax year ending on 5 Apr that contains date d. E.g. 2026-01-01 → 2026; 2025-04-05 → 2025; 2025-04-06 → 2026.""" return d.year + 1 if (d.month > 4 or (d.month == 4 and d.day >= 6)) else d.year def _compute_gross_and_relief( member_amount: Decimal, relief_type: str, ) -> tuple[Decimal, Decimal]: """Return (gross_amount, relief_amount) given a net member contribution and relief type.""" if relief_type == "relief_at_source": gross = member_amount / Decimal("0.8") relief = gross - member_amount elif relief_type in ("net_pay", "salary_sacrifice"): gross = member_amount relief = Decimal("0") else: # none gross = member_amount relief = Decimal("0") return gross.quantize(Decimal("0.01")), relief.quantize(Decimal("0.01")) def _enc(v: str | None) -> bytes | None: return encrypt_field(v) if v else None def _dec(v: bytes | None) -> str | None: return decrypt_field(v) if v else None # --------------------------------------------------------------------------- # Pension metadata (one per pension account) # --------------------------------------------------------------------------- async def get_pension_metadata( db: "AsyncSession", account_id: uuid.UUID, user_id: uuid.UUID, ) 
-> PensionMetadata | None: result = await db.execute( select(PensionMetadata).where( PensionMetadata.account_id == account_id, PensionMetadata.user_id == user_id, ) ) return result.scalar_one_or_none() async def upsert_pension_metadata( db: "AsyncSession", account_id: uuid.UUID, user_id: uuid.UUID, data: PensionMetadataCreate, ) -> PensionMetadata: existing = await get_pension_metadata(db, account_id, user_id) now = datetime.now(timezone.utc) if existing: existing.pension_type = data.pension_type existing.provider_name_enc = _enc(data.provider_name) existing.scheme_name_enc = _enc(data.scheme_name) existing.member_reference_enc = _enc(data.member_reference) existing.dob = data.dob existing.target_retirement_age = data.target_retirement_age existing.assumed_growth_rate = data.assumed_growth_rate existing.updated_at = now return existing meta = PensionMetadata( id=uuid.uuid4(), user_id=user_id, account_id=account_id, pension_type=data.pension_type, provider_name_enc=_enc(data.provider_name), scheme_name_enc=_enc(data.scheme_name), member_reference_enc=_enc(data.member_reference), dob=data.dob, target_retirement_age=data.target_retirement_age, assumed_growth_rate=data.assumed_growth_rate, created_at=now, updated_at=now, ) db.add(meta) return meta async def update_pension_metadata( db: "AsyncSession", account_id: uuid.UUID, user_id: uuid.UUID, data: PensionMetadataUpdate, ) -> PensionMetadata | None: meta = await get_pension_metadata(db, account_id, user_id) if not meta: return None now = datetime.now(timezone.utc) if data.pension_type is not None: meta.pension_type = data.pension_type if data.provider_name is not None: meta.provider_name_enc = _enc(data.provider_name) if data.scheme_name is not None: meta.scheme_name_enc = _enc(data.scheme_name) if data.member_reference is not None: meta.member_reference_enc = _enc(data.member_reference) if data.dob is not None: meta.dob = data.dob if data.target_retirement_age is not None: meta.target_retirement_age = 
data.target_retirement_age if data.assumed_growth_rate is not None: meta.assumed_growth_rate = data.assumed_growth_rate meta.updated_at = now return meta def decode_metadata(meta: PensionMetadata) -> dict: return { "id": meta.id, "account_id": meta.account_id, "pension_type": meta.pension_type, "provider_name": _dec(meta.provider_name_enc), "scheme_name": _dec(meta.scheme_name_enc), "member_reference": _dec(meta.member_reference_enc), "dob": meta.dob, "target_retirement_age": meta.target_retirement_age, "assumed_growth_rate": meta.assumed_growth_rate, "created_at": meta.created_at, "updated_at": meta.updated_at, } # --------------------------------------------------------------------------- # Contributions # --------------------------------------------------------------------------- async def list_contributions( db: "AsyncSession", pension_id: uuid.UUID, user_id: uuid.UUID, tax_year: int | None = None, ) -> list[dict]: q = select(PensionContribution).where( PensionContribution.pension_id == pension_id, PensionContribution.user_id == user_id, ) if tax_year is not None: q = q.where(PensionContribution.tax_year == tax_year) q = q.order_by(PensionContribution.contribution_date.desc()) result = await db.execute(q) rows = result.scalars().all() return [_decode_contribution(r) for r in rows] async def add_contribution( db: "AsyncSession", pension_id: uuid.UUID, user_id: uuid.UUID, data: PensionContributionCreate, ) -> dict: gross, relief = _compute_gross_and_relief(data.member_amount, data.relief_type) tax_year = _uk_tax_year(data.contribution_date) contrib = PensionContribution( id=uuid.uuid4(), user_id=user_id, pension_id=pension_id, contribution_date=data.contribution_date, tax_year=tax_year, member_amount=data.member_amount, employer_amount=data.employer_amount, relief_type=data.relief_type, gross_amount=gross, relief_amount=relief, notes_enc=_enc(data.notes), created_at=datetime.now(timezone.utc), ) db.add(contrib) return _decode_contribution(contrib) async def 
update_contribution( db: "AsyncSession", contribution_id: uuid.UUID, user_id: uuid.UUID, data: PensionContributionUpdate, ) -> dict | None: result = await db.execute( select(PensionContribution).where( PensionContribution.id == contribution_id, PensionContribution.user_id == user_id, ) ) contrib = result.scalar_one_or_none() if not contrib: return None if data.contribution_date is not None: contrib.contribution_date = data.contribution_date contrib.tax_year = _uk_tax_year(data.contribution_date) if data.member_amount is not None: contrib.member_amount = data.member_amount if data.employer_amount is not None: contrib.employer_amount = data.employer_amount if data.relief_type is not None: contrib.relief_type = data.relief_type if data.notes is not None: contrib.notes_enc = _enc(data.notes) gross, relief = _compute_gross_and_relief(contrib.member_amount, contrib.relief_type) contrib.gross_amount = gross contrib.relief_amount = relief return _decode_contribution(contrib) async def delete_contribution( db: "AsyncSession", contribution_id: uuid.UUID, user_id: uuid.UUID, ) -> bool: result = await db.execute( select(PensionContribution).where( PensionContribution.id == contribution_id, PensionContribution.user_id == user_id, ) ) contrib = result.scalar_one_or_none() if not contrib: return False await db.delete(contrib) return True def _decode_contribution(c: PensionContribution) -> dict: return { "id": c.id, "pension_id": c.pension_id, "contribution_date": c.contribution_date, "tax_year": c.tax_year, "member_amount": c.member_amount, "employer_amount": c.employer_amount, "relief_type": c.relief_type, "gross_amount": c.gross_amount, "relief_amount": c.relief_amount, "notes": _dec(c.notes_enc), "created_at": c.created_at, } # --------------------------------------------------------------------------- # YTD / summary # --------------------------------------------------------------------------- async def get_ytd_summary( db: "AsyncSession", pension_id: uuid.UUID, user_id: 
uuid.UUID, tax_year: int, ) -> YtdSummary: result = await db.execute( select( func.count(PensionContribution.id).label("count"), func.coalesce(func.sum(PensionContribution.member_amount), 0).label("member_total"), func.coalesce(func.sum(PensionContribution.employer_amount), 0).label("employer_total"), func.coalesce(func.sum(PensionContribution.gross_amount), 0).label("gross_total"), func.coalesce(func.sum(PensionContribution.relief_amount), 0).label("relief_total"), ).where( PensionContribution.pension_id == pension_id, PensionContribution.user_id == user_id, PensionContribution.tax_year == tax_year, ) ) row = result.one() return YtdSummary( tax_year=tax_year, member_total=Decimal(str(row.member_total)), employer_total=Decimal(str(row.employer_total)), gross_total=Decimal(str(row.gross_total)), relief_total=Decimal(str(row.relief_total)), contribution_count=row.count, ) async def get_all_pensions_ytd_summary( db: "AsyncSession", user_id: uuid.UUID, tax_year: int, ) -> YtdSummary: """Aggregate YTD totals across all pension accounts for the summary header.""" result = await db.execute( select( func.count(PensionContribution.id).label("count"), func.coalesce(func.sum(PensionContribution.member_amount), 0).label("member_total"), func.coalesce(func.sum(PensionContribution.employer_amount), 0).label("employer_total"), func.coalesce(func.sum(PensionContribution.gross_amount), 0).label("gross_total"), func.coalesce(func.sum(PensionContribution.relief_amount), 0).label("relief_total"), ).where( PensionContribution.user_id == user_id, PensionContribution.tax_year == tax_year, ) ) row = result.one() return YtdSummary( tax_year=tax_year, member_total=Decimal(str(row.member_total)), employer_total=Decimal(str(row.employer_total)), gross_total=Decimal(str(row.gross_total)), relief_total=Decimal(str(row.relief_total)), contribution_count=row.count, ) # --------------------------------------------------------------------------- # Annual allowance # 
# ---------------------------------------------------------------------------

# Standard annual allowance from tax year 2023/24 onwards.
_STANDARD_ALLOWANCE = Decimal("60000")
# Standard annual allowance up to and including tax year 2022/23.
_PRE_2024_STANDARD_ALLOWANCE = Decimal("40000")
# First tax year (labelled by the year in which it ends) with the £60,000 allowance.
_ALLOWANCE_RAISED_TAX_YEAR = 2024


def _standard_allowance_for(tax_year: int) -> Decimal:
    """Return the standard annual allowance for ``tax_year`` (year-ending label).

    Only the 2023/24 rise from £40,000 to £60,000 is modelled; rates before
    2014/15 differed and are not represented here.
    """
    if tax_year >= _ALLOWANCE_RAISED_TAX_YEAR:
        return _STANDARD_ALLOWANCE
    return _PRE_2024_STANDARD_ALLOWANCE


async def get_allowance_summary(
    db: "AsyncSession",
    user_id: uuid.UUID,
    tax_year: int,
) -> AllowanceSummary:
    """
    Returns annual allowance usage for tax_year plus carry-forward from the
    prior 3 tax years (oldest-first, as HMRC requires).

    Each year's unused allowance is measured against the standard allowance
    that applied in that year. Tapering, the money purchase annual allowance
    and the consumption of carry-forward by later years are not modelled.
    """
    prior_years = [tax_year - 3, tax_year - 2, tax_year - 1]
    years_to_fetch = [*prior_years, tax_year]

    result = await db.execute(
        select(
            PensionContribution.tax_year,
            PensionContribution.relief_type,
            func.coalesce(func.sum(PensionContribution.gross_amount), 0).label("gross_total"),
            func.coalesce(func.sum(PensionContribution.relief_amount), 0).label("relief_total"),
        )
        .where(
            PensionContribution.user_id == user_id,
            PensionContribution.tax_year.in_(years_to_fetch),
        )
        .group_by(
            PensionContribution.tax_year,
            PensionContribution.relief_type,
        )
    )

    zero = Decimal("0")
    # NOTE(review): only gross_amount is summed, which add_contribution derives
    # from the member amount alone; employer_amount is not counted toward the
    # allowance here. Confirm whether that is intended.
    year_gross: dict[int, Decimal] = dict.fromkeys(years_to_fetch, zero)
    year_ras_relief: dict[int, Decimal] = dict.fromkeys(years_to_fetch, zero)
    ras_gross_current = zero
    for row in result.all():
        gross = Decimal(str(row.gross_total))
        year_gross[row.tax_year] = year_gross.get(row.tax_year, zero) + gross
        if row.relief_type == "relief_at_source":
            year_ras_relief[row.tax_year] = (
                year_ras_relief.get(row.tax_year, zero) + Decimal(str(row.relief_total))
            )
            if row.tax_year == tax_year:
                ras_gross_current += gross

    carry_forward: list[CarryForwardYear] = []
    carry_forward_total = zero
    for yr in prior_years:
        allowance = _standard_allowance_for(yr)
        contribs = year_gross[yr]
        unused = max(zero, allowance - contribs)
        carry_forward.append(
            CarryForwardYear(
                tax_year=yr,
                standard_allowance=allowance,
                contributions=contribs,
                unused=unused,
            )
        )
        carry_forward_total += unused

    current_allowance = _standard_allowance_for(tax_year)
    contributions_total = year_gross[tax_year]

    # Extra relief claimable on relief-at-source contributions, as a share of
    # the gross amount.
    # NOTE(review): 0.05 looks like the increment over the higher-rate figure
    # (45% - 40%), not the full 25% an additional-rate taxpayer can claim on
    # top of basic-rate relief - confirm which one the schema field means.
    higher_rate_claimable = (ras_gross_current * Decimal("0.20")).quantize(Decimal("0.01"))
    additional_rate_claimable = (ras_gross_current * Decimal("0.05")).quantize(Decimal("0.01"))

    return AllowanceSummary(
        tax_year=tax_year,
        standard_allowance=current_allowance,
        contributions_total=contributions_total,
        remaining=max(zero, current_allowance - contributions_total),
        carry_forward=carry_forward,
        carry_forward_total=carry_forward_total,
        total_available=current_allowance + carry_forward_total,
        relief_ras_total=year_ras_relief[tax_year],
        relief_higher_rate_claimable=higher_rate_claimable,
        relief_additional_rate_claimable=additional_rate_claimable,
    )


# ---------------------------------------------------------------------------
# State Pension
# ---------------------------------------------------------------------------

# Qualifying years at which the weekly amount reaches _FULL_SP_WEEKLY.
_SP_FULL_QUALIFYING_YEARS = 35


def _sp_amounts(qualifying_years: int) -> tuple[Decimal, Decimal]:
    """Return ``(weekly, annual)`` State Pension for ``qualifying_years``.

    The weekly amount is the full rate pro-rated over 35 years and capped at
    the full rate; the annual amount is 52 times the weekly amount. Both are
    rounded to two decimal places.

    NOTE(review): no minimum number of qualifying years is applied - confirm
    whether a floor (e.g. 10 years) should yield zero.
    """
    pro_rata = (_FULL_SP_WEEKLY * qualifying_years / _SP_FULL_QUALIFYING_YEARS).quantize(
        Decimal("0.01")
    )
    weekly = min(_FULL_SP_WEEKLY, pro_rata)
    return weekly, (weekly * 52).quantize(Decimal("0.01"))


async def get_state_pension(
    db: "AsyncSession",
    user_id: uuid.UUID,
) -> StatePensionRecord | None:
    """Return the user's State Pension record, or ``None`` if none is stored."""
    result = await db.execute(
        select(StatePensionRecord).where(StatePensionRecord.user_id == user_id)
    )
    return result.scalar_one_or_none()


async def upsert_state_pension(
    db: "AsyncSession",
    user_id: uuid.UUID,
    data: StatePensionCreate,
) -> dict:
    """Create or overwrite the user's State Pension record and return it decoded.

    A new record is added to the session; nothing is flushed or committed here.
    """
    now = datetime.now(timezone.utc)
    record = await get_state_pension(db, user_id)
    if record:
        record.qualifying_years = data.qualifying_years
        record.checked_date = data.checked_date
        record.updated_at = now
    else:
        record = StatePensionRecord(
            id=uuid.uuid4(),
            user_id=user_id,
            qualifying_years=data.qualifying_years,
            checked_date=data.checked_date,
            created_at=now,
            updated_at=now,
        )
        db.add(record)
    return _decode_sp(record)


def _decode_sp(record: StatePensionRecord) -> dict:
    """Return ``record`` as a dict, adding the derived weekly/annual amounts."""
    weekly, annual = _sp_amounts(record.qualifying_years)
    return {
        "id": record.id,
        "qualifying_years": record.qualifying_years,
        "checked_date": record.checked_date,
        "weekly_amount": weekly,
        "annual_amount": annual,
        "is_full_pension": record.qualifying_years >= _SP_FULL_QUALIFYING_YEARS,
        "years_to_full": max(0, _SP_FULL_QUALIFYING_YEARS - record.qualifying_years),
        "state_pension_age": _SP_AGE,
    }


# ---------------------------------------------------------------------------
# Retirement projection
# ---------------------------------------------------------------------------

_SCENARIOS = [
    ("Conservative 2%", Decimal("0.02")),
    ("Moderate 5%", Decimal("0.05")),
    ("Growth 8%", Decimal("0.08")),
]


async def get_retirement_projection(
    db: "AsyncSession",
    user_id: uuid.UUID,
    account_id: uuid.UUID,
) -> RetirementProjection:
    """Project the account balance to the target retirement age.

    The current balance is compounded annually at each rate in ``_SCENARIOS``
    with no further contributions, producing one chart point per year from
    this year up to and including the retirement year.

    Raises:
        ValueError: if the account does not exist, is soft-deleted or belongs
            to another user, or if the pension metadata lacks a date of birth
            or target retirement age.
    """
    acc_result = await db.execute(
        select(Account).where(
            Account.id == account_id,
            Account.user_id == user_id,
            Account.deleted_at.is_(None),
        )
    )
    account = acc_result.scalar_one_or_none()
    if not account:
        raise ValueError("Account not found")

    meta = await get_pension_metadata(db, account_id, user_id)
    if not meta or not meta.dob or not meta.target_retirement_age:
        raise ValueError("Set date of birth and target retirement age in pension details first")

    today = date.today()
    # Approximate age in years; the remaining years are truncated to whole years.
    age_today = (today - meta.dob).days / 365.25
    years_to_retirement = max(0, int(meta.target_retirement_age - age_today))

    # Treat a missing balance as zero rather than failing on Decimal("None").
    current_balance = Decimal(str(account.current_balance or 0))

    # Build year-by-year chart data
    whole_pounds = Decimal("1")
    low, mid, high = (rate for _, rate in _SCENARIOS)
    pots = {rate: current_balance for _, rate in _SCENARIOS}
    chart_data: list[ChartDataPoint] = []
    for offset in range(years_to_retirement + 1):
        chart_data.append(
            ChartDataPoint(
                year=today.year + offset,
                pot_2pct=pots[low].quantize(whole_pounds),
                pot_5pct=pots[mid].quantize(whole_pounds),
                pot_8pct=pots[high].quantize(whole_pounds),
            )
        )
        for rate in pots:
            pots[rate] = pots[rate] * (1 + rate)

    # Final pot values at retirement: the loop above always runs at least once,
    # so the last chart point always exists.
    final = chart_data[-1]
    scenario_pots = {
        low: Decimal(str(final.pot_2pct)),
        mid: Decimal(str(final.pot_5pct)),
        high: Decimal(str(final.pot_8pct)),
    }

    scenarios = [
        ProjectionScenario(
            label=label,
            growth_rate=rate,
            projected_pot=scenario_pots[rate],
            annual_drawdown_4pct=(scenario_pots[rate] * Decimal("0.04")).quantize(whole_pounds),
            annual_drawdown_3pct=(scenario_pots[rate] * Decimal("0.03")).quantize(whole_pounds),
        )
        for label, rate in _SCENARIOS
    ]

    # State pension
    sp_record = await get_state_pension(db, user_id)
    sp_annual = None
    if sp_record:
        _, sp_annual = _sp_amounts(sp_record.qualifying_years)

    acc_name = decrypt_field(account.name_enc) if account.name_enc else ""

    return RetirementProjection(
        account_id=account_id,
        account_name=acc_name,
        current_balance=current_balance,
        years_to_retirement=years_to_retirement,
        target_retirement_age=meta.target_retirement_age,
        scenarios=scenarios,
        state_pension_annual=sp_annual,
        state_pension_age=_SP_AGE,
        chart_data=chart_data,
    )


# ---------------------------------------------------------------------------
# LISA summary
# ---------------------------------------------------------------------------

_LISA_ANNUAL_LIMIT = Decimal("4000")
_LISA_BONUS_RATE = Decimal("0.25")
_LISA_WITHDRAWAL_PENALTY = Decimal("0.25")


def _lisa_year_figures(raw_contributions: Decimal) -> tuple[Decimal, Decimal, Decimal]:
    """Return ``(contributions, bonus_expected, limit_remaining)`` for one tax year.

    Contributions are rounded to two decimal places; the bonus is 25% of the
    contributions capped at the annual limit; the remaining limit never goes
    below zero.
    """
    two_places = Decimal("0.01")
    contributions = raw_contributions.quantize(two_places)
    capped = min(contributions, _LISA_ANNUAL_LIMIT)
    bonus = (capped * _LISA_BONUS_RATE).quantize(two_places)
    limit_remaining = max(Decimal("0"), _LISA_ANNUAL_LIMIT - contributions).quantize(two_places)
    return contributions, bonus, limit_remaining


async def get_lisa_summary(
    db: "AsyncSession",
    user_id: uuid.UUID,
    account_id: uuid.UUID,
) -> LisaSummary:
    """Summarise LISA contributions, expected bonuses and the withdrawal penalty.

    Member contributions are grouped by tax year; each year's bonus is 25% of
    that year's contributions capped at the annual limit.

    Raises:
        ValueError: if the account has no pension metadata of type ``"lisa"``,
            or the account does not exist, is soft-deleted or belongs to
            another user.
    """
    # Verify account is LISA type
    meta = await get_pension_metadata(db, account_id, user_id)
    if meta is None or meta.pension_type != "lisa":
        raise ValueError("Account is not a LISA")

    # Load account for balance and created_at; soft-deleted accounts are
    # excluded, matching get_retirement_projection.
    acc_result = await db.execute(
        select(Account).where(
            Account.id == account_id,
            Account.user_id == user_id,
            Account.deleted_at.is_(None),
        )
    )
    account = acc_result.scalar_one_or_none()
    if account is None:
        raise ValueError("Account not found")

    current_balance = Decimal(str(account.current_balance or 0))

    # Load all contributions
    contrib_result = await db.execute(
        select(PensionContribution).where(
            PensionContribution.pension_id == meta.id,
            PensionContribution.user_id == user_id,
        )
    )

    # Group by tax year
    zero = Decimal("0")
    two_places = Decimal("0.01")
    by_year: dict[int, Decimal] = {}
    for c in contrib_result.scalars().all():
        by_year[c.tax_year] = by_year.get(c.tax_year, zero) + Decimal(str(c.member_amount))

    breakdown: list[LisaTaxYearBreakdown] = []
    for ty in sorted(by_year):
        contributions, bonus, limit_remaining = _lisa_year_figures(by_year[ty])
        used_pct = min(Decimal("100"), contributions / _LISA_ANNUAL_LIMIT * 100)
        breakdown.append(
            LisaTaxYearBreakdown(
                tax_year=ty,
                contributions=contributions,
                bonus_expected=bonus,
                limit_remaining=limit_remaining,
                limit_used_pct=used_pct.quantize(two_places),
            )
        )

    current_year = _uk_tax_year(date.today())
    current_contributions, current_bonus, current_limit_remaining = _lisa_year_figures(
        by_year.get(current_year, zero)
    )

    total_contributions = sum(by_year.values(), zero).quantize(two_places)
    total_capped = sum((min(v, _LISA_ANNUAL_LIMIT) for v in by_year.values()), zero)
    total_bonus = (total_capped * _LISA_BONUS_RATE).quantize(two_places)

    # Withdrawal penalty: 25% of current balance (which includes any bonus already paid)
    penalty_amount = (current_balance * _LISA_WITHDRAWAL_PENALTY).quantize(two_places)
    penalty_pct = _LISA_WITHDRAWAL_PENALTY * 100

    return LisaSummary(
        account_id=account_id,
        account_name=decrypt_field(account.name_enc) if account.name_enc else "",
        tax_year_breakdown=breakdown,
        current_year_contributions=current_contributions,
        current_year_bonus_expected=current_bonus,
        current_year_limit_remaining=current_limit_remaining,
        total_contributions=total_contributions,
        total_bonus_expected=total_bonus,
        account_opened_date=account.created_at,
        withdrawal_penalty_amount=penalty_amount,
        withdrawal_penalty_pct=penalty_pct,
        penalty_warning=True,
    )