"""
Pure UK tax calculation functions.

Zero external dependencies — no DB, no ORM. Each function receives a
pre-loaded `rates` dict so they are fully unit-testable.

Tax year convention: tax_year=N means 6 Apr (N-1) → 5 Apr N.

Income-tax bands in ``rates["income_tax"]["bands"]`` are read as the published
gross-income thresholds (e.g. 0% to 12,570; 20% to 50,270; 40% to 125,140;
45% above), which assume the standard personal allowance. They are rebased
internally onto taxable income so that tapered, non-standard, 0T and K-code
allowances are charged correctly.
"""

from __future__ import annotations

import re
from datetime import date
from decimal import ROUND_HALF_UP, Decimal
from typing import Any

_ZERO = Decimal("0")
_PENNY = Decimal("0.01")

# Emergency-basis markers ("1257L W1", "1257L/M1", "1257L X") do not change the allowance.
_EMERGENCY_SUFFIX = re.compile(r"(?:[/\s]*(?:W1|M1|X))+$")
# Scottish ("S") or Welsh ("C") taxpayer prefix followed by the ordinary code.
_COUNTRY_PREFIX = re.compile(r"([SC])(.+)")
# Codes that charge every pound at one flat rate (rUK rates; Welsh rates are identical).
_FLAT_RATE_CODES = {
    "BR": Decimal("0.20"),
    "D0": Decimal("0.40"),
    "D1": Decimal("0.45"),
}


# ---------------------------------------------------------------------------
# Tax year helpers
# ---------------------------------------------------------------------------

def tax_year_for_date(d: date) -> int:
    """Return tax_year int for a calendar date. tax_year=N = 6 Apr (N-1) → 5 Apr N."""
    if (d.month, d.day) >= (4, 6):
        return d.year + 1
    return d.year


def tax_year_date_range(tax_year: int) -> tuple[date, date]:
    """Return (start_date, end_date) inclusive for the given tax year."""
    return date(tax_year - 1, 4, 6), date(tax_year, 4, 5)


# ---------------------------------------------------------------------------
# Tax code parser
# ---------------------------------------------------------------------------

def parse_tax_code(code: str) -> dict[str, Any]:
    """Parse a UK tax code string.

    Emergency-basis markers (W1, M1, X) are ignored. A leading Scottish ("S")
    or Welsh ("C") prefix is stripped so the allowance is still recognised
    (e.g. "S1257L" → £12,570). Scottish flat-rate D codes (SD0, SD1, ...) use
    Scotland-specific rates and are not recognised; like any other unknown
    code they are treated as 0T.

    Returns:
        allowance — annual personal allowance in £ (negative for K codes)
        rate_override — flat rate (0.0–1.0) if code fixes a single rate, else None
        k_code — True if K prefix (negative allowance)
        no_tax — True if NT code
    """
    def _result(
        allowance: Decimal = _ZERO,
        rate_override: Decimal | None = None,
        k_code: bool = False,
        no_tax: bool = False,
    ) -> dict[str, Any]:
        return {"allowance": allowance, "rate_override": rate_override, "k_code": k_code, "no_tax": no_tax}

    raw = _EMERGENCY_SUFFIX.sub("", code.strip().upper())

    prefixed = _COUNTRY_PREFIX.fullmatch(raw)
    # Keep "SD.." codes intact: their flat rates differ from the rUK D0/D1 rates.
    if prefixed and not (prefixed.group(1) == "S" and prefixed.group(2).startswith("D")):
        raw = prefixed.group(2)

    if raw == "NT":
        return _result(rate_override=Decimal("0"), no_tax=True)
    if raw in _FLAT_RATE_CODES:
        return _result(rate_override=_FLAT_RATE_CODES[raw])

    k_match = re.fullmatch(r"K(\d+)", raw)
    if k_match:
        return _result(allowance=-Decimal(k_match.group(1)) * 10, k_code=True)

    # Also covers "0T" (allowance 0).
    std_match = re.fullmatch(r"(\d+)[LMNTY]?", raw)
    if std_match:
        return _result(allowance=Decimal(std_match.group(1)) * 10)

    # Unknown code — treat as 0T
    return _result()


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------

def _money(value: Decimal) -> Decimal:
    """Round a monetary amount to the nearest penny, half-up."""
    return value.quantize(_PENNY, rounding=ROUND_HALF_UP)


def _rate_of(band: dict) -> Decimal:
    """Return a band's rate as a Decimal, whether stored as float, str or Decimal."""
    return Decimal(str(band["rate"]))


def _band_upper(bands: list[dict], rate: str) -> Decimal:
    """Return the ``to`` threshold of the first band charged at *rate*.

    Raises StopIteration if no band carries that rate.
    """
    target = Decimal(rate)
    return Decimal(str(next(b["to"] for b in bands if _rate_of(b) == target)))


def _apply_bands(amount: Decimal, bands: list[dict]) -> tuple[Decimal, list[dict]]:
    """Charge *amount* through ascending bands, rounding the tax in each band to the penny.

    Returns (total tax, per-band breakdown); bands with nothing in them are omitted
    from the breakdown.
    """
    total = _ZERO
    breakdown = []
    for band in bands:
        band_from = Decimal(str(band["from"]))
        band_to = Decimal(str(band["to"])) if band["to"] is not None else None
        rate = _rate_of(band)
        if amount <= band_from:
            break
        upper = min(amount, band_to) if band_to is not None else amount
        taxable_in_band = upper - band_from
        tax_in_band = _money(taxable_in_band * rate)
        total += tax_in_band
        if taxable_in_band > 0:
            breakdown.append({
                "from": int(band_from),
                "to": int(band_to) if band_to is not None else None,
                "rate": float(rate),
                "taxable": float(taxable_in_band),
                "tax": float(tax_in_band),
            })
    return total, breakdown


def _personal_allowance_tapered(base_allowance: Decimal, gross_income: Decimal) -> Decimal:
    """Reduce PA by £1 per £2 over £100,000; floor at zero (at £125,140 for a £12,570 PA).

    Gross income is used as a stand-in for adjusted net income.
    """
    taper_threshold = Decimal("100000")
    if gross_income <= taper_threshold:
        return base_allowance
    reduction = ((gross_income - taper_threshold) / 2).quantize(Decimal("1"), rounding=ROUND_HALF_UP)
    return max(_ZERO, base_allowance - reduction)


def _taxable_limit(gross_threshold: Decimal, standard_pa: Decimal) -> Decimal:
    """Convert a published gross-income threshold into a taxable-income limit.

    Published thresholds assume the standard allowance as tapered at that income,
    so 50,270 → 37,700 (basic rate limit) while 125,140 → 125,140 (the allowance
    has fully tapered away there).
    """
    return gross_threshold - _personal_allowance_tapered(standard_pa, gross_threshold)


def _taxable_bands(bands: list[dict]) -> list[dict]:
    """Rebase gross-income bands onto taxable income, dropping the 0% allowance band."""
    standard_pa = _band_upper(bands, "0")
    rebased = []
    for band in bands:
        if _rate_of(band) == 0:
            continue
        lower = _taxable_limit(Decimal(str(band["from"])), standard_pa)
        upper = None if band["to"] is None else _taxable_limit(Decimal(str(band["to"])), standard_pa)
        rebased.append({"from": lower, "to": upper, "rate": band["rate"]})
    return rebased


def _overlap(start: Decimal, end: Decimal, lo: Decimal, hi: Decimal | None) -> Decimal:
    """Length of the interval [start, end) that lies inside [lo, hi); hi=None is unbounded."""
    top = end if hi is None else min(end, hi)
    return max(_ZERO, top - max(start, lo))


def _charge_slices(slices: list[tuple[Decimal, Decimal]]) -> tuple[Decimal, list[dict]]:
    """Tax each (amount, rate) slice, rounding per slice so the breakdown sums to the liability."""
    liability = _ZERO
    breakdown = []
    for amount, rate in slices:
        if amount <= 0:
            continue
        tax = _money(amount * rate)
        liability += tax
        breakdown.append({"rate": float(rate), "taxable": float(amount), "tax": float(tax)})
    return liability, breakdown


# ---------------------------------------------------------------------------
# Core calculation functions
# ---------------------------------------------------------------------------

def calculate_income_tax(
    gross_income: Decimal,
    tax_code: str,
    rates: dict,
) -> dict[str, Any]:
    """Calculate income tax liability and remaining basic-rate band.

    The personal allowance comes from the tax code (tapered above £100,000 for
    positive allowances). Tax is charged on taxable income = gross − allowance,
    using the rate bands rebased onto taxable income. The basic-rate band is
    therefore 37,700 regardless of the allowance, and the additional rate starts
    at 125,140 of taxable income.

    K codes carry no personal allowance: the K amount (which already nets off
    the allowance) is added to gross income. The PAYE 50% overriding limit is
    not applied.

    Returns:
        personal_allowance, taxable_income, liability,
        band_breakdown (from/to are taxable-income bounds),
        remaining_basic_rate_band (unused part of the basic-rate band measured
        against taxable income; passed downstream to CGT/dividend calculations)
    """
    parsed = parse_tax_code(tax_code)
    bands = rates["income_tax"]["bands"]

    standard_pa = _band_upper(bands, "0")
    basic_rate_limit = _taxable_limit(_band_upper(bands, "0.20"), standard_pa)

    if parsed["no_tax"]:
        notional_taxable = max(_ZERO, gross_income - standard_pa)
        return {
            "personal_allowance": standard_pa,
            "taxable_income": _ZERO,
            "liability": _ZERO,
            "band_breakdown": [],
            "remaining_basic_rate_band": max(_ZERO, basic_rate_limit - notional_taxable),
        }

    if parsed["rate_override"] is not None:
        liability = _money(gross_income * parsed["rate_override"])
        return {
            "personal_allowance": _ZERO,
            "taxable_income": gross_income,
            "liability": liability,
            "band_breakdown": [{"rate": float(parsed["rate_override"]), "tax": float(liability)}],
            "remaining_basic_rate_band": _ZERO,
        }

    if parsed["k_code"]:
        personal_allowance = _ZERO
        taxable_income = gross_income + abs(parsed["allowance"])
    else:
        base_pa = parsed["allowance"]
        personal_allowance = _personal_allowance_tapered(base_pa, gross_income) if base_pa > 0 else base_pa
        taxable_income = max(_ZERO, gross_income - personal_allowance)

    liability, band_breakdown = _apply_bands(taxable_income, _taxable_bands(bands))

    return {
        "personal_allowance": personal_allowance,
        "taxable_income": taxable_income,
        "liability": liability,
        "band_breakdown": band_breakdown,
        "remaining_basic_rate_band": max(_ZERO, basic_rate_limit - taxable_income),
    }


def calculate_ni(gross_income: Decimal, rates: dict) -> dict[str, Any]:
    """Calculate primary Class 1 NI liability by charging gross income through rates["ni"]["bands"]."""
    bands = rates["ni"]["bands"]
    liability, band_breakdown = _apply_bands(gross_income, bands)
    return {"liability": liability, "band_breakdown": band_breakdown}


def calculate_cgt(
    total_gain: Decimal,
    remaining_basic_rate_band: Decimal,
    rates: dict,
) -> dict[str, Any]:
    """Calculate CGT liability.

    The annual exempt amount is deducted first (exempt gains do not use up the
    basic-rate band). Taxable gains within the remaining basic-rate band are
    taxed at basic_rate; the rest at higher_rate. A negative remaining band is
    treated as zero.
    """
    cgt_rates = rates["cgt"]
    exempt = Decimal(str(cgt_rates["exempt"]))
    basic_rate = Decimal(str(cgt_rates["basic_rate"]))
    higher_rate = Decimal(str(cgt_rates["higher_rate"]))

    taxable_gain = max(_ZERO, total_gain - exempt)
    if taxable_gain == 0:
        return {
            "gross_gain": total_gain,
            "exempt": min(total_gain, exempt) if total_gain > 0 else _ZERO,
            "taxable_gain": _ZERO,
            "liability": _ZERO,
            "band_breakdown": [],
        }

    basic_portion = min(taxable_gain, max(_ZERO, remaining_basic_rate_band))
    higher_portion = taxable_gain - basic_portion
    liability, breakdown = _charge_slices([(basic_portion, basic_rate), (higher_portion, higher_rate)])

    return {
        "gross_gain": total_gain,
        "exempt": min(total_gain, exempt),
        "taxable_gain": taxable_gain,
        "liability": liability,
        "band_breakdown": breakdown,
    }


def calculate_dividend_tax(
    total_dividends: Decimal,
    remaining_basic_rate_band: Decimal,
    rates: dict,
    taxable_income: Decimal | None = None,
) -> dict[str, Any]:
    """Calculate dividend tax liability.

    Dividends are stacked on top of other taxable income as one stream. The
    dividend allowance covers the first slice at 0% but still uses up band
    space, so the allowance can push later dividends into a higher band.

    Band room, in stream order:
      * basic: ``remaining_basic_rate_band``;
      * higher: the higher-rate band width (higher-rate limit − basic-rate
        limit, in taxable-income terms), less any part already used by other
        income;
      * additional: everything beyond.

    ``taxable_income`` (optional) is the other taxable income. It is only
    consulted when the basic-rate band is exhausted, to tell how much of the
    higher-rate band is already used. When it is omitted, the whole
    higher-rate band is assumed to be available.

    Not modelled: dividends tapering the personal allowance, and unused
    personal allowance covering dividends.
    """
    div_rates = rates["dividend"]
    allowance = Decimal(str(div_rates["allowance"]))
    basic_rate = Decimal(str(div_rates["basic_rate"]))
    higher_rate = Decimal(str(div_rates["higher_rate"]))
    additional_rate = Decimal(str(div_rates["additional_rate"]))

    taxable_dividends = max(_ZERO, total_dividends - allowance)
    if taxable_dividends == 0:
        return {
            "gross_dividends": total_dividends,
            "allowance": min(total_dividends, allowance),
            "taxable_dividends": _ZERO,
            "liability": _ZERO,
            "band_breakdown": [],
        }

    bands = rates["income_tax"]["bands"]
    standard_pa = _band_upper(bands, "0")
    basic_rate_limit = _taxable_limit(_band_upper(bands, "0.20"), standard_pa)
    higher_rate_limit = _taxable_limit(_band_upper(bands, "0.40"), standard_pa)

    basic_room = max(_ZERO, remaining_basic_rate_band)
    used_higher = _ZERO
    if basic_room == 0 and taxable_income is not None:
        used_higher = max(_ZERO, taxable_income - basic_rate_limit)
    higher_room = max(_ZERO, higher_rate_limit - basic_rate_limit - used_higher)

    # Taxable part of the stream is [allowance, total); the allowance slice comes first.
    start = min(total_dividends, allowance)
    basic_portion = _overlap(start, total_dividends, _ZERO, basic_room)
    higher_portion = _overlap(start, total_dividends, basic_room, basic_room + higher_room)
    additional_portion = _overlap(start, total_dividends, basic_room + higher_room, None)

    liability, breakdown = _charge_slices([
        (basic_portion, basic_rate),
        (higher_portion, higher_rate),
        (additional_portion, additional_rate),
    ])

    return {
        "gross_dividends": total_dividends,
        "allowance": min(total_dividends, allowance),
        "taxable_dividends": taxable_dividends,
        "liability": liability,
        "band_breakdown": breakdown,
    }