"""
UK tax service layer for MyMidas.

Pure calculation functions live in tax_calculations.py (no DB imports).
This module provides the DB-backed service layer: rate loading, CRUD and
the report builder.
"""
from __future__ import annotations

import logging
import uuid
from collections.abc import Iterable
from datetime import date, datetime, timezone
from decimal import Decimal
from typing import Any

from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.security import decrypt_field, encrypt_field
from app.services.tax_calculations import (  # noqa: F401 — re-exported for callers
    calculate_cgt,
    calculate_dividend_tax,
    calculate_income_tax,
    calculate_ni,
    parse_tax_code,
    tax_year_date_range,
    tax_year_for_date,
)

logger = logging.getLogger(__name__)

# Tax years are identified by the calendar year in which they end
# (build_tax_report displays 2025 as "2024/25").

# Rate types that every configured tax year must provide.
_REQUIRED_RATE_TYPES = frozenset({"income_tax", "ni", "cgt", "dividend"})

# Tax years seeded with default rates for a newly registered user.
_DEFAULT_RATE_YEARS = (2025, 2026)

# Tax code the report assumes when the user has no profile for the year.
_DEFAULT_TAX_CODE = "1257L"

# Identity/ownership columns set by create_payslip that update_payslip must never
# overwrite: changing them would re-parent the row (possibly onto another user's
# profile) or store notes without going through encrypt_field.
_PAYSLIP_PROTECTED_FIELDS = frozenset(
    {"id", "user_id", "tax_profile_id", "created_at", "notes_enc"}
)


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------

def _safe_decrypt(ciphertext: Any, *, label: str, resource_id: Any) -> str | None:
    """Decrypt an optional encrypted column for display.

    Returns None when there is nothing to decrypt or when decryption fails.
    Failures are logged (label, resource id and exception type only — never the
    ciphertext or the exception message) rather than silently discarded, so key
    problems surface in logs without breaking read endpoints.
    """
    if not ciphertext:
        return None
    try:
        return decrypt_field(ciphertext)
    except Exception as exc:  # best-effort: one bad field must not break the response
        logger.warning(
            "Failed to decrypt %s for %s (%s)", label, resource_id, type(exc).__name__
        )
        return None


def _sum_decimal(values: Iterable[Any]) -> Decimal:
    """Sum numeric values as Decimals, converting each via str() to avoid float artefacts."""
    return sum((Decimal(str(v)) for v in values), Decimal("0"))


def _stringify_decimals(
    result: dict[str, Any], exclude: Iterable[str] = ()
) -> dict[str, Any]:
    """Copy *result* with top-level Decimal values rendered as str, omitting *exclude* keys."""
    skip = set(exclude)
    return {
        key: str(value) if isinstance(value, Decimal) else value
        for key, value in result.items()
        if key not in skip
    }


# ---------------------------------------------------------------------------
# DB helpers
# ---------------------------------------------------------------------------

async def load_rates(db: AsyncSession, user_id: uuid.UUID, tax_year: int) -> dict:
    """Load and return the rate config dict for a given user/year.

    Returns {"income_tax": {...}, "ni": {...}, "cgt": {...}, "dividend": {...}}
    (plus any other rate types stored for the year).

    Raises:
        ValueError: if no rates exist for the year, or any required rate type
            is missing.
    """
    from app.db.models.tax import TaxRateConfig

    result = await db.execute(
        select(TaxRateConfig).where(
            TaxRateConfig.user_id == user_id,
            TaxRateConfig.tax_year == tax_year,
        )
    )
    rates = {row.rate_type: row.config for row in result.scalars()}
    if not rates:
        raise ValueError(
            f"No tax rate config found for year {tax_year}. Please configure rates first."
        )

    missing = set(_REQUIRED_RATE_TYPES) - rates.keys()
    if missing:
        raise ValueError(f"Incomplete tax rate config for year {tax_year}: missing {missing}")
    return rates


def _default_rate_configs() -> dict[str, dict[str, Any]]:
    """Return freshly built default rate configs keyed by rate type.

    A new dict is built on every call so rows for different tax years never
    share (and accidentally co-mutate) the same JSON object.
    """
    return {
        "income_tax": {
            "bands": [
                {"from": 0, "to": 12570, "rate": 0.00},
                {"from": 12570, "to": 50270, "rate": 0.20},
                {"from": 50270, "to": 125140, "rate": 0.40},
                {"from": 125140, "to": None, "rate": 0.45},
            ]
        },
        "ni": {
            "bands": [
                {"from": 0, "to": 12570, "rate": 0.00},
                {"from": 12570, "to": 50270, "rate": 0.08},
                {"from": 50270, "to": None, "rate": 0.02},
            ]
        },
        # NOTE(review): 18%/24% CGT rates took effect mid-year (30 Oct 2024);
        # applying them to all of tax year 2025 is an approximation — confirm intended.
        "cgt": {"exempt": 3000, "basic_rate": 0.18, "higher_rate": 0.24},
        "dividend": {
            "allowance": 500,
            "basic_rate": 0.0875,
            "higher_rate": 0.3375,
            "additional_rate": 0.3935,
        },
    }


async def seed_default_rates(db: AsyncSession, user_id: uuid.UUID) -> None:
    """Insert default 2025 and 2026 rate configs for a newly registered user.

    Existing (tax_year, rate_type) rows are left untouched, so calling this
    again is harmless. New rows are only added to the session; committing is
    the caller's job.
    """
    from app.db.models.tax import TaxRateConfig

    now = datetime.now(timezone.utc)

    # One query for everything already configured, instead of one per pair.
    existing = await db.execute(
        select(TaxRateConfig.tax_year, TaxRateConfig.rate_type).where(
            TaxRateConfig.user_id == user_id,
            TaxRateConfig.tax_year.in_(_DEFAULT_RATE_YEARS),
        )
    )
    already_configured = {(year, rate_type) for year, rate_type in existing}

    for tax_year in _DEFAULT_RATE_YEARS:
        for rate_type, config in _default_rate_configs().items():
            if (tax_year, rate_type) in already_configured:
                continue
            db.add(TaxRateConfig(
                id=uuid.uuid4(),
                user_id=user_id,
                tax_year=tax_year,
                rate_type=rate_type,
                config=config,
                updated_at=now,
            ))


# ---------------------------------------------------------------------------
# Tax profile CRUD
# ---------------------------------------------------------------------------

async def get_tax_profile(
    db: AsyncSession, user_id: uuid.UUID, tax_year: int
):
    """Return the user's TaxProfile for *tax_year*, or None if there is none."""
    from app.db.models.tax import TaxProfile

    result = await db.execute(
        select(TaxProfile).where(
            TaxProfile.user_id == user_id,
            TaxProfile.tax_year == tax_year,
        )
    )
    return result.scalar_one_or_none()


async def upsert_tax_profile(
    db: AsyncSession,
    user_id: uuid.UUID,
    tax_year: int,
    tax_code: str,
    employer_name: str | None,
    is_cumulative: bool,
):
    """Create or update the user's tax profile for *tax_year* and audit the change.

    The employer name is stored encrypted; a falsy name clears it.
    Returns the flushed TaxProfile.
    """
    from app.db.models.tax import TaxProfile
    from app.core.audit import write_audit

    now = datetime.now(timezone.utc)
    profile = await get_tax_profile(db, user_id, tax_year)
    employer_enc = encrypt_field(employer_name) if employer_name else None

    if profile is None:
        profile = TaxProfile(
            id=uuid.uuid4(),
            user_id=user_id,
            tax_year=tax_year,
            tax_code=tax_code,
            employer_name_enc=employer_enc,
            is_cumulative=is_cumulative,
            created_at=now,
            updated_at=now,
        )
        db.add(profile)
        action = "tax_profile_create"
    else:
        profile.tax_code = tax_code
        profile.employer_name_enc = employer_enc
        profile.is_cumulative = is_cumulative
        profile.updated_at = now
        action = "tax_profile_update"

    await db.flush()
    await write_audit(db, user_id=user_id, action=action,
                      resource_type="tax_profile", resource_id=profile.id)
    return profile


def _profile_to_response(profile) -> dict:
    """Serialise a TaxProfile to a JSON-friendly dict (employer name decrypted)."""
    employer = _safe_decrypt(
        profile.employer_name_enc, label="employer_name", resource_id=profile.id
    )
    return {
        "id": str(profile.id),
        "tax_year": profile.tax_year,
        "tax_code": profile.tax_code,
        "employer_name": employer,
        "is_cumulative": profile.is_cumulative,
        "created_at": profile.created_at.isoformat(),
        "updated_at": profile.updated_at.isoformat(),
    }


# ---------------------------------------------------------------------------
# Payslip CRUD
# ---------------------------------------------------------------------------

async def list_payslips(db: AsyncSession, user_id: uuid.UUID, tax_year: int) -> list:
    """Return the user's payslips for *tax_year*, ordered by period (P60 rows last in a year)."""
    from app.db.models.tax import Payslip, TaxProfile

    result = await db.execute(
        select(Payslip)
        .join(TaxProfile, Payslip.tax_profile_id == TaxProfile.id)
        .where(
            Payslip.user_id == user_id,
            TaxProfile.tax_year == tax_year,
        )
        .order_by(Payslip.period_year, Payslip.period_month.nulls_last())
    )
    return list(result.scalars())


async def create_payslip(
    db: AsyncSession,
    user_id: uuid.UUID,
    tax_year: int,
    period_month: int | None,
    period_year: int,
    gross_pay: Decimal,
    income_tax_withheld: Decimal,
    ni_withheld: Decimal,
    net_pay: Decimal,
    is_p60: bool = False,
    notes: str | None = None,
):
    """Create a payslip under the user's profile for *tax_year* and audit it.

    Raises:
        ValueError: if the user has no tax profile for *tax_year*.
    """
    from app.db.models.tax import Payslip
    from app.core.audit import write_audit

    profile = await get_tax_profile(db, user_id, tax_year)
    if profile is None:
        raise ValueError(f"No tax profile for year {tax_year}. Create a profile first.")

    now = datetime.now(timezone.utc)
    notes_enc = encrypt_field(notes) if notes else None
    payslip = Payslip(
        id=uuid.uuid4(),
        user_id=user_id,
        tax_profile_id=profile.id,
        period_month=period_month,
        period_year=period_year,
        gross_pay=gross_pay,
        income_tax_withheld=income_tax_withheld,
        ni_withheld=ni_withheld,
        net_pay=net_pay,
        is_p60=is_p60,
        notes_enc=notes_enc,
        created_at=now,
    )
    db.add(payslip)
    await db.flush()
    await write_audit(db, user_id=user_id, action="payslip_create",
                      resource_type="payslip", resource_id=payslip.id)
    return payslip


async def update_payslip(
    db: AsyncSession,
    user_id: uuid.UUID,
    payslip_id: uuid.UUID,
    **kwargs,
):
    """Apply field updates to one of the user's payslips and audit the change.

    *kwargs* map payslip attribute names to new values. ``notes`` is encrypted
    into ``notes_enc`` (a falsy value clears it). Identity/ownership columns in
    _PAYSLIP_PROTECTED_FIELDS are refused so a caller cannot move the payslip to
    another user/profile or write unencrypted notes.

    Raises:
        ValueError: if a protected field is supplied, or the payslip does not
            exist for this user.
    """
    from app.db.models.tax import Payslip
    from app.core.audit import write_audit

    protected = _PAYSLIP_PROTECTED_FIELDS.intersection(kwargs)
    if protected:
        raise ValueError(
            f"Cannot update protected payslip field(s): {', '.join(sorted(protected))}"
        )

    result = await db.execute(
        select(Payslip).where(Payslip.id == payslip_id, Payslip.user_id == user_id)
    )
    payslip = result.scalar_one_or_none()
    if payslip is None:
        raise ValueError("Payslip not found")

    for field, value in kwargs.items():
        if field == "notes":
            payslip.notes_enc = encrypt_field(value) if value else None
        else:
            setattr(payslip, field, value)

    await db.flush()
    await write_audit(db, user_id=user_id, action="payslip_update",
                      resource_type="payslip", resource_id=payslip.id)
    return payslip


async def delete_payslip(db: AsyncSession, user_id: uuid.UUID, payslip_id: uuid.UUID) -> None:
    """Delete one of the user's payslips, writing the audit entry first.

    Raises:
        ValueError: if the payslip does not exist for this user.
    """
    from app.db.models.tax import Payslip
    from app.core.audit import write_audit

    result = await db.execute(
        select(Payslip).where(Payslip.id == payslip_id, Payslip.user_id == user_id)
    )
    payslip = result.scalar_one_or_none()
    if payslip is None:
        raise ValueError("Payslip not found")

    await write_audit(db, user_id=user_id, action="payslip_delete",
                      resource_type="payslip", resource_id=payslip_id)
    await db.delete(payslip)
    await db.flush()


async def replace_with_p60(
    db: AsyncSession,
    user_id: uuid.UUID,
    tax_year: int,
    gross_pay: Decimal,
    income_tax_withheld: Decimal,
    ni_withheld: Decimal,
    net_pay: Decimal,
) -> None:
    """Delete all existing payslips for the tax year and replace them with a single P60.

    Raises:
        ValueError: if the user has no tax profile for *tax_year*.
    """
    from app.db.models.tax import Payslip
    from app.core.audit import write_audit

    profile = await get_tax_profile(db, user_id, tax_year)
    if profile is None:
        raise ValueError(f"No tax profile for year {tax_year}. Create a profile first.")

    await db.execute(
        delete(Payslip).where(
            Payslip.user_id == user_id,
            Payslip.tax_profile_id == profile.id,
        )
    )

    now = datetime.now(timezone.utc)
    p60 = Payslip(
        id=uuid.uuid4(),
        user_id=user_id,
        tax_profile_id=profile.id,
        period_month=None,
        # P60 covers the year ending 5 Apr tax_year; that year starts in the prior calendar year.
        period_year=tax_year - 1,
        gross_pay=gross_pay,
        income_tax_withheld=income_tax_withheld,
        ni_withheld=ni_withheld,
        net_pay=net_pay,
        is_p60=True,
        notes_enc=None,
        created_at=now,
    )
    db.add(p60)
    await db.flush()
    await write_audit(db, user_id=user_id, action="payslip_p60_replace",
                      resource_type="payslip", resource_id=p60.id)


def _payslip_to_response(payslip) -> dict:
    """Serialise a Payslip to a JSON-friendly dict (money as str, notes decrypted)."""
    notes = _safe_decrypt(payslip.notes_enc, label="payslip notes", resource_id=payslip.id)
    return {
        "id": str(payslip.id),
        "tax_profile_id": str(payslip.tax_profile_id),
        "period_month": payslip.period_month,
        "period_year": payslip.period_year,
        "gross_pay": str(payslip.gross_pay),
        "income_tax_withheld": str(payslip.income_tax_withheld),
        "ni_withheld": str(payslip.ni_withheld),
        "net_pay": str(payslip.net_pay),
        "is_p60": payslip.is_p60,
        "notes": notes,
        "created_at": payslip.created_at.isoformat(),
    }


# ---------------------------------------------------------------------------
# Manual CGT disposal CRUD
# ---------------------------------------------------------------------------

async def list_manual_disposals(
    db: AsyncSession, user_id: uuid.UUID, tax_year: int
) -> list:
    """Return the user's manual CGT disposals for *tax_year*, ordered by disposal date."""
    from app.db.models.tax import ManualCGTDisposal

    result = await db.execute(
        select(ManualCGTDisposal).where(
            ManualCGTDisposal.user_id == user_id,
            ManualCGTDisposal.tax_year == tax_year,
        ).order_by(ManualCGTDisposal.disposal_date)
    )
    return list(result.scalars())


async def create_manual_disposal(
    db: AsyncSession,
    user_id: uuid.UUID,
    tax_year: int,
    disposal_date: date,
    asset_description: str,
    proceeds: Decimal,
    cost_basis: Decimal,
    notes: str | None = None,
):
    """Record a manual CGT disposal (description/notes encrypted) and audit it."""
    from app.db.models.tax import ManualCGTDisposal
    from app.core.audit import write_audit

    now = datetime.now(timezone.utc)
    disposal = ManualCGTDisposal(
        id=uuid.uuid4(),
        user_id=user_id,
        tax_year=tax_year,
        disposal_date=disposal_date,
        asset_description_enc=encrypt_field(asset_description),
        proceeds=proceeds,
        cost_basis=cost_basis,
        notes_enc=encrypt_field(notes) if notes else None,
        created_at=now,
    )
    db.add(disposal)
    await db.flush()
    await write_audit(db, user_id=user_id, action="cgt_disposal_create",
                      resource_type="manual_cgt_disposal", resource_id=disposal.id)
    return disposal


async def update_manual_disposal(
    db: AsyncSession,
    user_id: uuid.UUID,
    disposal_id: uuid.UUID,
    disposal_date: date,
    asset_description: str,
    proceeds: Decimal,
    cost_basis: Decimal,
    notes: str | None = None,
):
    """Overwrite the editable fields of one of the user's manual disposals and audit it.

    The stored tax_year is not changed.

    Raises:
        ValueError: if the disposal does not exist for this user.
    """
    from app.db.models.tax import ManualCGTDisposal
    from app.core.audit import write_audit

    result = await db.execute(
        select(ManualCGTDisposal).where(
            ManualCGTDisposal.id == disposal_id,
            ManualCGTDisposal.user_id == user_id,
        )
    )
    disposal = result.scalar_one_or_none()
    if disposal is None:
        raise ValueError("Disposal not found")

    disposal.disposal_date = disposal_date
    disposal.asset_description_enc = encrypt_field(asset_description)
    disposal.proceeds = proceeds
    disposal.cost_basis = cost_basis
    disposal.notes_enc = encrypt_field(notes) if notes else None

    await db.flush()
    await write_audit(db, user_id=user_id, action="cgt_disposal_update",
                      resource_type="manual_cgt_disposal", resource_id=disposal.id)
    return disposal


async def delete_manual_disposal(
    db: AsyncSession, user_id: uuid.UUID, disposal_id: uuid.UUID
) -> None:
    """Delete one of the user's manual disposals, writing the audit entry first.

    Raises:
        ValueError: if the disposal does not exist for this user.
    """
    from app.db.models.tax import ManualCGTDisposal
    from app.core.audit import write_audit

    result = await db.execute(
        select(ManualCGTDisposal).where(
            ManualCGTDisposal.id == disposal_id,
            ManualCGTDisposal.user_id == user_id,
        )
    )
    disposal = result.scalar_one_or_none()
    if disposal is None:
        raise ValueError("Disposal not found")

    await write_audit(db, user_id=user_id, action="cgt_disposal_delete",
                      resource_type="manual_cgt_disposal", resource_id=disposal_id)
    await db.delete(disposal)
    await db.flush()


def _disposal_to_response(disposal) -> dict:
    """Serialise a manual disposal to a JSON-friendly dict, including gain_loss."""
    asset_desc = _safe_decrypt(
        disposal.asset_description_enc, label="asset_description", resource_id=disposal.id
    ) or ""
    notes = _safe_decrypt(disposal.notes_enc, label="disposal notes", resource_id=disposal.id)
    gain_loss = disposal.proceeds - disposal.cost_basis
    return {
        "id": str(disposal.id),
        "tax_year": disposal.tax_year,
        "disposal_date": disposal.disposal_date.isoformat(),
        "asset_description": asset_desc,
        "proceeds": str(disposal.proceeds),
        "cost_basis": str(disposal.cost_basis),
        "gain_loss": str(gain_loss),
        "notes": notes,
        "created_at": disposal.created_at.isoformat(),
    }


# ---------------------------------------------------------------------------
# Tax rate config CRUD
# ---------------------------------------------------------------------------

async def list_configured_years(db: AsyncSession, user_id: uuid.UUID) -> list[int]:
    """Return the distinct tax years the user has any rate config for, ascending."""
    from app.db.models.tax import TaxRateConfig

    result = await db.execute(
        select(TaxRateConfig.tax_year)
        .where(TaxRateConfig.user_id == user_id)
        .distinct()
        .order_by(TaxRateConfig.tax_year)
    )
    return list(result.scalars())


async def get_rate_config(
    db: AsyncSession, user_id: uuid.UUID, tax_year: int
) -> dict:
    """Return {"tax_year", "rates", "updated_at"} for the user's configured year.

    ``updated_at`` is the most recent update across the year's rate rows.

    Raises:
        ValueError: if no rates are configured for *tax_year*.
    """
    from app.db.models.tax import TaxRateConfig

    result = await db.execute(
        select(TaxRateConfig).where(
            TaxRateConfig.user_id == user_id,
            TaxRateConfig.tax_year == tax_year,
        )
    )
    rows = list(result.scalars())
    if not rows:
        raise ValueError(f"No rate config for year {tax_year}")
    return {
        "tax_year": tax_year,
        "rates": {row.rate_type: row.config for row in rows},
        "updated_at": max(row.updated_at for row in rows).isoformat(),
    }


async def upsert_rate_config(
    db: AsyncSession,
    user_id: uuid.UUID,
    tax_year: int,
    rates: dict,
) -> dict:
    """Insert or overwrite each rate type in *rates* for *tax_year*, audit, and return the result.

    Rate types not present in *rates* are left unchanged.
    """
    from app.db.models.tax import TaxRateConfig
    from app.core.audit import write_audit

    now = datetime.now(timezone.utc)
    for rate_type, config in rates.items():
        result = await db.execute(
            select(TaxRateConfig).where(
                TaxRateConfig.user_id == user_id,
                TaxRateConfig.tax_year == tax_year,
                TaxRateConfig.rate_type == rate_type,
            )
        )
        row = result.scalar_one_or_none()
        if row is None:
            db.add(TaxRateConfig(
                id=uuid.uuid4(),
                user_id=user_id,
                tax_year=tax_year,
                rate_type=rate_type,
                config=config,
                updated_at=now,
            ))
        else:
            row.config = config
            row.updated_at = now

    await db.flush()
    await write_audit(db, user_id=user_id, action="tax_rate_config_update",
                      resource_type="tax_rate_config", resource_id=None,
                      metadata={"tax_year": tax_year})
    return await get_rate_config(db, user_id, tax_year)


# ---------------------------------------------------------------------------
# Report builder
# ---------------------------------------------------------------------------

async def _investment_transactions(
    db: AsyncSession,
    user_id: uuid.UUID,
    txn_type: str,
    start_date: date,
    end_date: date,
) -> list:
    """Return (transaction, holding, asset) rows of *txn_type* dated within [start_date, end_date]."""
    from app.db.models.asset import Asset
    from app.db.models.investment_holding import InvestmentHolding
    from app.db.models.investment_transaction import InvestmentTransaction

    result = await db.execute(
        select(InvestmentTransaction, InvestmentHolding, Asset)
        .join(InvestmentHolding, InvestmentTransaction.holding_id == InvestmentHolding.id)
        .join(Asset, InvestmentHolding.asset_id == Asset.id)
        .where(
            InvestmentTransaction.user_id == user_id,
            InvestmentTransaction.type == txn_type,
            InvestmentTransaction.date >= start_date,
            InvestmentTransaction.date <= end_date,
        )
        .order_by(InvestmentTransaction.date)
    )
    return list(result)


async def build_tax_report(
    db: AsyncSession, user_id: uuid.UUID, tax_year: int
) -> dict[str, Any]:
    """Build the full tax report for a given year.

    Steps:
      1. Load rates
      2. Load profile + payslip totals
      3. Load investment sell disposals within the tax year
      4. Load manual CGT disposals
      5. Load investment dividend transactions within the tax year
      6. Run calculations: income tax → NI → dividend tax → CGT
      7. Return full report dict (money values rendered as str)

    Raises:
        ValueError: propagated from load_rates when rates are missing/incomplete.
    """
    rates = await load_rates(db, user_id, tax_year)
    start_date, end_date = tax_year_date_range(tax_year)

    # ---- Profile ----
    profile = await get_tax_profile(db, user_id, tax_year)
    tax_code = profile.tax_code if profile else _DEFAULT_TAX_CODE
    profile_data = _profile_to_response(profile) if profile else None

    # ---- Payslip totals ----
    payslips = await list_payslips(db, user_id, tax_year)
    gross_income = _sum_decimal(p.gross_pay for p in payslips)
    income_tax_withheld = _sum_decimal(p.income_tax_withheld for p in payslips)
    ni_withheld = _sum_decimal(p.ni_withheld for p in payslips)
    payslip_rows = [_payslip_to_response(p) for p in payslips]

    # ---- Investment sell disposals ----
    inv_disposal_rows: list[dict[str, Any]] = []
    total_inv_gain = Decimal("0")
    sells = await _investment_transactions(db, user_id, "sell", start_date, end_date)
    for inv_txn, holding, asset in sells:
        proceeds = Decimal(str(inv_txn.total_amount))
        # NOTE(review): avg_cost_basis is the holding's current value, not the
        # pool cost at the disposal date, so later trades shift earlier gains.
        # Fees are subtracted separately, which assumes total_amount excludes
        # them — confirm against the transaction model.
        cost = Decimal(str(inv_txn.quantity)) * Decimal(str(holding.avg_cost_basis))
        gain = proceeds - cost - Decimal(str(inv_txn.fees))
        total_inv_gain += gain
        inv_disposal_rows.append({
            "date": inv_txn.date.isoformat(),
            "asset": asset.name,
            "symbol": asset.symbol,
            "quantity": str(inv_txn.quantity),
            "proceeds": str(proceeds),
            "cost_basis": str(cost),
            "fees": str(inv_txn.fees),
            "gain_loss": str(gain),
        })

    # ---- Manual CGT disposals ----
    manual_disposals = await list_manual_disposals(db, user_id, tax_year)
    manual_disposal_rows = [_disposal_to_response(d) for d in manual_disposals]
    total_manual_gain = sum(
        (Decimal(str(d.proceeds)) - Decimal(str(d.cost_basis)) for d in manual_disposals),
        Decimal("0"),
    )
    total_cgt_gain = total_inv_gain + total_manual_gain

    # ---- Investment dividends ----
    dividend_rows: list[dict[str, Any]] = []
    total_dividends = Decimal("0")
    dividends = await _investment_transactions(db, user_id, "dividend", start_date, end_date)
    for inv_txn, _holding, asset in dividends:
        amount = Decimal(str(inv_txn.total_amount))
        total_dividends += amount
        dividend_rows.append({
            "date": inv_txn.date.isoformat(),
            "asset": asset.name,
            "symbol": asset.symbol,
            "amount": str(amount),
        })

    # ---- Calculations ----
    income_tax_result = calculate_income_tax(gross_income, tax_code, rates)
    ni_result = calculate_ni(gross_income, rates)
    remaining_band = income_tax_result["remaining_basic_rate_band"]

    dividend_result = calculate_dividend_tax(total_dividends, remaining_band, rates)

    # Dividends (including the part covered by the dividend allowance) are stacked
    # above employment income and use up the basic-rate band before gains do.
    # Passing CGT the same band as dividends would count it twice.
    # NOTE(review): dividends covered by an unused personal allowance do not
    # consume the basic-rate band; this subtraction ignores that edge case.
    cgt_basic_band = max(Decimal("0"), Decimal(str(remaining_band)) - total_dividends)
    cgt_result = calculate_cgt(total_cgt_gain, cgt_basic_band, rates)

    # ---- Totals ----
    total_liability = (
        income_tax_result["liability"]
        + ni_result["liability"]
        + cgt_result["liability"]
        + dividend_result["liability"]
    )
    total_withheld = income_tax_withheld + ni_withheld
    net_owed = total_liability - total_withheld

    return {
        "tax_year": tax_year,
        "tax_year_display": f"{tax_year - 1}/{str(tax_year)[2:]}",
        "profile": profile_data,
        "income": {
            "gross_income": str(gross_income),
            "income_tax_withheld": str(income_tax_withheld),
            "ni_withheld": str(ni_withheld),
            "payslips": payslip_rows,
        },
        "income_tax": {
            # remaining_basic_rate_band is an internal input to the CGT/dividend maths.
            **_stringify_decimals(income_tax_result, exclude=("remaining_basic_rate_band",)),
            "withheld": str(income_tax_withheld),
            "owed": str(income_tax_result["liability"] - income_tax_withheld),
        },
        "ni": {
            **_stringify_decimals(ni_result),
            "withheld": str(ni_withheld),
            "owed": str(ni_result["liability"] - ni_withheld),
        },
        "cgt": {
            **_stringify_decimals(cgt_result),
            "investment_disposals": inv_disposal_rows,
            "manual_disposals": manual_disposal_rows,
            "total_gain": str(total_cgt_gain),
        },
        "dividends": {
            **_stringify_decimals(dividend_result),
            "dividend_transactions": dividend_rows,
        },
        "summary": {
            "total_liability": str(total_liability),
            "total_withheld": str(total_withheld),
            "net_owed": str(net_owed),
            "overpaid": net_owed < 0,
        },
    }