Add recurring transaction detection, subscriptions page, and UK tax reporting

- Recurring service: auto-detects direct debits/subscriptions from CSV imports
  using frequency analysis; manual toggle in transaction detail drawer
- Subscriptions page (/subscriptions): groups recurring payments with monthly
  cost equivalents, next-payment badges, and re-scan trigger
- UK Tax page (/tax): payslips/P60 entry, income tax + NI + CGT + dividend tax
  calculations, configurable rate tables per tax year (pre-seeded 2024/25 and
  2025/26), editable in-app so Budget changes need no rebuild
- Migration 0006: tax_rate_configs, tax_profiles, payslips, manual_cgt_disposals
  with RLS; seeds 2025/2026 rate configs for existing users
- Chart tooltip fix: all Recharts tooltips now use TOOLTIP_STYLE constant so
  they render correctly across all dark/light themes

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
megaproxy 2026-04-23 21:40:02 +00:00
parent 0b326cbd87
commit afb5e99bb2
48 changed files with 6238 additions and 39 deletions

View file

@ -85,6 +85,10 @@ async def register_user(db: AsyncSession, email: str, password: str, display_nam
)
db.add(user)
await db.flush()
from app.services.tax_service import seed_default_rates
await seed_default_rates(db, user.id)
return user

View file

@ -0,0 +1,194 @@
from __future__ import annotations
import re
import uuid
from collections import defaultdict
from datetime import date, datetime, timezone
from decimal import Decimal
from statistics import mean, stdev
from dateutil.relativedelta import relativedelta
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.security import decrypt_field
from app.db.models.transaction import Transaction
MIN_OCCURRENCES = 2
# (label, min_days, max_days, tolerance_days)
_FREQUENCIES = [
("weekly", 6, 8, 2),
("fortnightly", 13, 15, 3),
("monthly", 26, 35, 5),
("quarterly", 85, 95, 10),
("yearly", 355, 375, 15),
]
_STRIP_PREFIXES = re.compile(
r"^(direct debit|standing order|faster payment|bacs|dd|so)\s+",
re.IGNORECASE,
)
_STRIP_REFS = re.compile(r"\s+\d{5,}$")
_STRIP_DATE_PATTERNS = re.compile(
r"\b\d{1,2}[/-]\d{1,2}([/-]\d{2,4})?\b"
r"|\b\d{1,2}(jan|feb|mar|apr|may|jun|jul|aug|sep|oct|nov|dec)\b",
re.IGNORECASE,
)
_COLLAPSE_SPACES = re.compile(r"\s{2,}")
def normalise_description(raw: str) -> str:
s = raw.lower().strip()
s = _STRIP_PREFIXES.sub("", s)
s = _STRIP_DATE_PATTERNS.sub("", s)
s = _STRIP_REFS.sub("", s)
s = _COLLAPSE_SPACES.sub(" ", s).strip()
return s
def classify_frequency(avg_days: float) -> tuple[str, int] | None:
"""Return (label, expected_days) or None if avg_days matches no known frequency."""
for label, lo, hi, _ in _FREQUENCIES:
if lo <= avg_days <= hi:
return label, round(avg_days)
return None
def _within_tolerance(intervals: list[int], avg: float, frequency: str) -> bool:
tolerance = next(t for label, _, _, t in _FREQUENCIES if label == frequency)
return all(abs(iv - avg) <= tolerance for iv in intervals)
def next_expected_date(last: date, frequency: str) -> date:
delta_map = {
"weekly": relativedelta(weeks=1),
"fortnightly": relativedelta(weeks=2),
"monthly": relativedelta(months=1),
"quarterly": relativedelta(months=3),
"yearly": relativedelta(years=1),
}
delta = delta_map[frequency]
result = last + delta
today = date.today()
while result < today:
result += delta
return result
def _confidence(intervals: list[int], expected: int) -> float:
if len(intervals) < 2:
return 1.0
try:
sd = stdev(intervals)
except Exception:
sd = 0.0
conf = 1.0 - (sd / expected) if expected > 0 else 0.0
return round(max(0.0, min(1.0, conf)), 4)
async def detect_recurring(db: AsyncSession, user_id: uuid.UUID) -> dict:
"""
Scan all transactions for a user, detect recurring patterns, and tag them.
Skips transactions where recurring_rule.manually_set == true.
Returns {"newly_tagged": int, "total_recurring": int}.
"""
result = await db.execute(
select(Transaction).where(
Transaction.user_id == user_id,
Transaction.deleted_at.is_(None),
)
)
transactions = result.scalars().all()
txn_map: dict[uuid.UUID, Transaction] = {t.id: t for t in transactions}
# Group by (normalised_description, exact_amount)
keyed: dict[tuple[str, Decimal], tuple[list[date], list[uuid.UUID]]] = defaultdict(
lambda: ([], [])
)
for txn in transactions:
try:
desc = decrypt_field(txn.description_enc) or ""
except Exception:
desc = ""
norm = normalise_description(desc)
amount = txn.amount.quantize(Decimal("0.01"))
dates_list, ids_list = keyed[(norm, amount)]
dates_list.append(txn.date)
ids_list.append(txn.id)
now = datetime.now(timezone.utc)
newly_tagged = 0
total_recurring = 0
matched_ids: set[uuid.UUID] = set()
for (norm_desc, amount), (dates_list, ids_list) in keyed.items():
if len(dates_list) < MIN_OCCURRENCES:
continue
paired = sorted(zip(dates_list, ids_list), key=lambda x: x[0])
sorted_dates = [p[0] for p in paired]
sorted_ids = [p[1] for p in paired]
intervals = [
(sorted_dates[i + 1] - sorted_dates[i]).days
for i in range(len(sorted_dates) - 1)
]
avg = mean(intervals)
freq_result = classify_frequency(avg)
if freq_result is None:
continue
frequency, expected_days = freq_result
if not _within_tolerance(intervals, avg, frequency):
continue
conf = _confidence(intervals, expected_days)
last_date = sorted_dates[-1]
next_date = next_expected_date(last_date, frequency)
for txn_id in sorted_ids:
txn = txn_map.get(txn_id)
if txn is None:
continue
matched_ids.add(txn_id)
rule = txn.recurring_rule or {}
if rule.get("manually_set"):
if txn.is_recurring:
total_recurring += 1
continue
was_recurring = txn.is_recurring
txn.is_recurring = True
txn.recurring_rule = {
"frequency": frequency,
"typical_amount": float(amount),
"next_expected": next_date.isoformat(),
"last_paid": last_date.isoformat(),
"confidence": conf,
"detected_at": now.isoformat(),
"manually_set": False,
}
txn.updated_at = now
if not was_recurring:
newly_tagged += 1
total_recurring += 1
# Un-tag previously auto-detected transactions whose pattern no longer matches
for txn in transactions:
if not txn.is_recurring:
continue
rule = txn.recurring_rule or {}
if rule.get("manually_set"):
continue
if txn.id not in matched_ids:
txn.is_recurring = False
txn.recurring_rule = None
txn.updated_at = now
await db.flush()
return {"newly_tagged": newly_tagged, "total_recurring": total_recurring}

View file

@ -0,0 +1,301 @@
"""
Pure UK tax calculation functions. Zero external dependencies no DB, no ORM.
Each function receives a pre-loaded `rates` dict so they are fully unit-testable.
Tax year convention: tax_year=N means 6 Apr (N-1) 5 Apr N.
"""
from __future__ import annotations
import re
from datetime import date
from decimal import ROUND_HALF_UP, Decimal
from typing import Any
# ---------------------------------------------------------------------------
# Tax year helpers
# ---------------------------------------------------------------------------
def tax_year_for_date(d: date) -> int:
"""Return tax_year int for a calendar date. tax_year=N = 6 Apr (N-1) → 5 Apr N."""
if (d.month, d.day) >= (4, 6):
return d.year + 1
return d.year
def tax_year_date_range(tax_year: int) -> tuple[date, date]:
"""Return (start_date, end_date) inclusive for the given tax year."""
return date(tax_year - 1, 4, 6), date(tax_year, 4, 5)
# ---------------------------------------------------------------------------
# Tax code parser
# ---------------------------------------------------------------------------
def parse_tax_code(code: str) -> dict[str, Any]:
"""Parse a UK tax code string.
Returns:
allowance annual personal allowance in £ (negative for K codes)
rate_override flat rate (0.01.0) if code fixes a single rate, else None
k_code True if K prefix (negative allowance)
no_tax True if NT code
"""
raw = code.strip().upper()
raw = re.sub(r"[/\s]?(W1|M1)$", "", raw)
if raw == "NT":
return {"allowance": Decimal("0"), "rate_override": Decimal("0"), "k_code": False, "no_tax": True}
if raw == "BR":
return {"allowance": Decimal("0"), "rate_override": Decimal("0.20"), "k_code": False, "no_tax": False}
if raw == "D0":
return {"allowance": Decimal("0"), "rate_override": Decimal("0.40"), "k_code": False, "no_tax": False}
if raw == "D1":
return {"allowance": Decimal("0"), "rate_override": Decimal("0.45"), "k_code": False, "no_tax": False}
if raw == "0T":
return {"allowance": Decimal("0"), "rate_override": None, "k_code": False, "no_tax": False}
k_match = re.fullmatch(r"K(\d+)", raw)
if k_match:
return {"allowance": -Decimal(k_match.group(1)) * 10, "rate_override": None, "k_code": True, "no_tax": False}
std_match = re.fullmatch(r"(\d+)[LMNTY]?", raw)
if std_match:
return {"allowance": Decimal(std_match.group(1)) * 10, "rate_override": None, "k_code": False, "no_tax": False}
# Unknown code — treat as 0T
return {"allowance": Decimal("0"), "rate_override": None, "k_code": False, "no_tax": False}
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _apply_bands(amount: Decimal, bands: list[dict]) -> tuple[Decimal, list[dict]]:
total = Decimal("0")
breakdown = []
for band in bands:
band_from = Decimal(str(band["from"]))
band_to = Decimal(str(band["to"])) if band["to"] is not None else None
rate = Decimal(str(band["rate"]))
if amount <= band_from:
break
upper = min(amount, band_to) if band_to is not None else amount
taxable_in_band = upper - band_from
tax_in_band = (taxable_in_band * rate).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
total += tax_in_band
if taxable_in_band > 0:
breakdown.append({
"from": int(band_from),
"to": int(band_to) if band_to is not None else None,
"rate": float(rate),
"taxable": float(taxable_in_band),
"tax": float(tax_in_band),
})
return total, breakdown
def _personal_allowance_tapered(base_allowance: Decimal, gross_income: Decimal) -> Decimal:
"""Reduce PA by £1 per £2 over £100,000; floor at zero at £125,140."""
taper_threshold = Decimal("100000")
if gross_income <= taper_threshold:
return base_allowance
reduction = ((gross_income - taper_threshold) / 2).quantize(Decimal("1"), rounding=ROUND_HALF_UP)
return max(Decimal("0"), base_allowance - reduction)
# ---------------------------------------------------------------------------
# Core calculation functions
# ---------------------------------------------------------------------------
def calculate_income_tax(
gross_income: Decimal,
tax_code: str,
rates: dict,
) -> dict[str, Any]:
"""Calculate income tax liability and remaining basic-rate band.
Bands are applied to GROSS income. The 0% band threshold is adjusted to match
the actual personal allowance from the tax code (with taper applied if applicable).
K codes add their amount to gross income before applying the standard bands.
Returns:
personal_allowance, taxable_income, liability, band_breakdown,
remaining_basic_rate_band (passed downstream to CGT/dividend calculations)
"""
parsed = parse_tax_code(tax_code)
bands = rates["income_tax"]["bands"]
# These are gross-income thresholds from the band definitions
pa_threshold = Decimal(str(next(b["to"] for b in bands if b["rate"] == 0.00)))
basic_rate_upper = Decimal(str(next(b["to"] for b in bands if b["rate"] == 0.20)))
if parsed["no_tax"]:
return {
"personal_allowance": pa_threshold,
"taxable_income": Decimal("0"),
"liability": Decimal("0"),
"band_breakdown": [],
"remaining_basic_rate_band": max(Decimal("0"), basic_rate_upper - gross_income),
}
if parsed["rate_override"] is not None:
liability = (gross_income * parsed["rate_override"]).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
return {
"personal_allowance": Decimal("0"),
"taxable_income": gross_income,
"liability": liability,
"band_breakdown": [{"rate": float(parsed["rate_override"]), "tax": float(liability)}],
"remaining_basic_rate_band": Decimal("0"),
}
if parsed["k_code"]:
# K codes: the K amount adds to taxable base; standard PA band still applies to effective gross.
# effective_gross is the "notional income" HMRC uses to calculate tax.
k_amount = abs(parsed["allowance"])
effective_gross = gross_income + k_amount
personal_allowance = Decimal("0") # K code replaces any standard PA grant
taxable_income = effective_gross # reported as the effective taxable base
liability, band_breakdown = _apply_bands(effective_gross, bands)
remaining_brb = max(Decimal("0"), basic_rate_upper - effective_gross)
else:
base_pa = parsed["allowance"]
personal_allowance = _personal_allowance_tapered(base_pa, gross_income) if base_pa > 0 else base_pa
# Adjust the 0% band to match the actual personal allowance, then apply to gross income.
adjusted_bands = [
{"from": 0, "to": float(personal_allowance), "rate": 0.00} if b["rate"] == 0.00 else b
for b in bands
]
taxable_income = max(Decimal("0"), gross_income - personal_allowance)
liability, band_breakdown = _apply_bands(gross_income, adjusted_bands)
remaining_brb = max(Decimal("0"), basic_rate_upper - gross_income)
return {
"personal_allowance": personal_allowance,
"taxable_income": taxable_income,
"liability": liability,
"band_breakdown": band_breakdown,
"remaining_basic_rate_band": remaining_brb,
}
def calculate_ni(gross_income: Decimal, rates: dict) -> dict[str, Any]:
"""Calculate primary Class 1 NI liability."""
bands = rates["ni"]["bands"]
liability, band_breakdown = _apply_bands(gross_income, bands)
return {"liability": liability, "band_breakdown": band_breakdown}
def calculate_cgt(
total_gain: Decimal,
remaining_basic_rate_band: Decimal,
rates: dict,
) -> dict[str, Any]:
"""Calculate CGT liability.
Gains within the remaining basic-rate band are taxed at basic_rate;
gains above it at higher_rate. Annual exempt amount applied first.
"""
cgt_rates = rates["cgt"]
exempt = Decimal(str(cgt_rates["exempt"]))
basic_rate = Decimal(str(cgt_rates["basic_rate"]))
higher_rate = Decimal(str(cgt_rates["higher_rate"]))
taxable_gain = max(Decimal("0"), total_gain - exempt)
if taxable_gain == 0:
return {
"gross_gain": total_gain,
"exempt": min(total_gain, exempt) if total_gain > 0 else Decimal("0"),
"taxable_gain": Decimal("0"),
"liability": Decimal("0"),
"band_breakdown": [],
}
basic_portion = min(taxable_gain, remaining_basic_rate_band)
higher_portion = taxable_gain - basic_portion
liability = (
(basic_portion * basic_rate) + (higher_portion * higher_rate)
).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
breakdown = []
if basic_portion > 0:
breakdown.append({"rate": float(basic_rate), "taxable": float(basic_portion),
"tax": float((basic_portion * basic_rate).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP))})
if higher_portion > 0:
breakdown.append({"rate": float(higher_rate), "taxable": float(higher_portion),
"tax": float((higher_portion * higher_rate).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP))})
return {
"gross_gain": total_gain,
"exempt": min(total_gain, exempt),
"taxable_gain": taxable_gain,
"liability": liability,
"band_breakdown": breakdown,
}
def calculate_dividend_tax(
total_dividends: Decimal,
remaining_basic_rate_band: Decimal,
rates: dict,
) -> dict[str, Any]:
"""Calculate dividend tax liability.
Dividend allowance applied first; taxable dividends are then slotted into
the remaining income bands to determine which rate applies.
"""
div_rates = rates["dividend"]
allowance = Decimal(str(div_rates["allowance"]))
basic_rate = Decimal(str(div_rates["basic_rate"]))
higher_rate = Decimal(str(div_rates["higher_rate"]))
additional_rate = Decimal(str(div_rates["additional_rate"]))
taxable_dividends = max(Decimal("0"), total_dividends - allowance)
if taxable_dividends == 0:
return {
"gross_dividends": total_dividends,
"allowance": min(total_dividends, allowance),
"taxable_dividends": Decimal("0"),
"liability": Decimal("0"),
"band_breakdown": [],
}
basic_portion = min(taxable_dividends, remaining_basic_rate_band)
remainder = taxable_dividends - basic_portion
higher_upper = Decimal(str(
next(b["to"] for b in rates["income_tax"]["bands"] if b["rate"] == 0.40)
))
higher_portion = min(remainder, higher_upper)
additional_portion = remainder - higher_portion
liability = (
(basic_portion * basic_rate)
+ (higher_portion * higher_rate)
+ (additional_portion * additional_rate)
).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
breakdown = []
if basic_portion > 0:
breakdown.append({"rate": float(basic_rate), "taxable": float(basic_portion),
"tax": float((basic_portion * basic_rate).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP))})
if higher_portion > 0:
breakdown.append({"rate": float(higher_rate), "taxable": float(higher_portion),
"tax": float((higher_portion * higher_rate).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP))})
if additional_portion > 0:
breakdown.append({"rate": float(additional_rate), "taxable": float(additional_portion),
"tax": float((additional_portion * additional_rate).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP))})
return {
"gross_dividends": total_dividends,
"allowance": min(total_dividends, allowance),
"taxable_dividends": taxable_dividends,
"liability": liability,
"band_breakdown": breakdown,
}

View file

@ -0,0 +1,744 @@
"""
UK tax service layer for MyMidas.
Pure calculation functions live in tax_calculations.py (no DB imports).
This module provides the DB-backed service layer: rate loading, CRUD, report builder.
"""
from __future__ import annotations
import uuid
from datetime import date, datetime, timezone
from decimal import Decimal
from typing import Any
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.security import decrypt_field, encrypt_field
from app.services.tax_calculations import ( # noqa: F401 — re-exported for callers
calculate_cgt,
calculate_dividend_tax,
calculate_income_tax,
calculate_ni,
parse_tax_code,
tax_year_date_range,
tax_year_for_date,
)
# ---------------------------------------------------------------------------
# DB helpers
# ---------------------------------------------------------------------------
async def load_rates(db: AsyncSession, user_id: uuid.UUID, tax_year: int) -> dict:
"""Load and return rate config dict for a given user/year.
Returns {"income_tax": {...}, "ni": {...}, "cgt": {...}, "dividend": {...}}
Raises ValueError if any rate type is missing for the requested year.
"""
from app.db.models.tax import TaxRateConfig
result = await db.execute(
select(TaxRateConfig).where(
TaxRateConfig.user_id == user_id,
TaxRateConfig.tax_year == tax_year,
)
)
rows = list(result.scalars())
if not rows:
raise ValueError(f"No tax rate config found for year {tax_year}. Please configure rates first.")
rates: dict = {}
for row in rows:
rates[row.rate_type] = row.config
required = {"income_tax", "ni", "cgt", "dividend"}
missing = required - rates.keys()
if missing:
raise ValueError(f"Incomplete tax rate config for year {tax_year}: missing {missing}")
return rates
async def seed_default_rates(db: AsyncSession, user_id: uuid.UUID) -> None:
"""Insert default 2025 and 2026 rate configs for a newly registered user."""
from app.db.models.tax import TaxRateConfig
now = datetime.now(timezone.utc)
income_tax_bands = {
"bands": [
{"from": 0, "to": 12570, "rate": 0.00},
{"from": 12570, "to": 50270, "rate": 0.20},
{"from": 50270, "to": 125140, "rate": 0.40},
{"from": 125140, "to": None, "rate": 0.45},
]
}
ni_bands = {
"bands": [
{"from": 0, "to": 12570, "rate": 0.00},
{"from": 12570, "to": 50270, "rate": 0.08},
{"from": 50270, "to": None, "rate": 0.02},
]
}
cgt = {"exempt": 3000, "basic_rate": 0.18, "higher_rate": 0.24}
dividend = {
"allowance": 500,
"basic_rate": 0.0875,
"higher_rate": 0.3375,
"additional_rate": 0.3935,
}
defaults = {
"income_tax": income_tax_bands,
"ni": ni_bands,
"cgt": cgt,
"dividend": dividend,
}
for tax_year in (2025, 2026):
for rate_type, config in defaults.items():
existing = await db.execute(
select(TaxRateConfig).where(
TaxRateConfig.user_id == user_id,
TaxRateConfig.tax_year == tax_year,
TaxRateConfig.rate_type == rate_type,
)
)
if existing.scalar_one_or_none() is None:
db.add(TaxRateConfig(
id=uuid.uuid4(),
user_id=user_id,
tax_year=tax_year,
rate_type=rate_type,
config=config,
updated_at=now,
))
# ---------------------------------------------------------------------------
# Tax profile CRUD
# ---------------------------------------------------------------------------
async def get_tax_profile(
db: AsyncSession, user_id: uuid.UUID, tax_year: int
):
from app.db.models.tax import TaxProfile
result = await db.execute(
select(TaxProfile).where(
TaxProfile.user_id == user_id,
TaxProfile.tax_year == tax_year,
)
)
return result.scalar_one_or_none()
async def upsert_tax_profile(
db: AsyncSession,
user_id: uuid.UUID,
tax_year: int,
tax_code: str,
employer_name: str | None,
is_cumulative: bool,
):
from app.db.models.tax import TaxProfile
from app.core.audit import write_audit
now = datetime.now(timezone.utc)
profile = await get_tax_profile(db, user_id, tax_year)
employer_enc = encrypt_field(employer_name) if employer_name else None
if profile is None:
profile = TaxProfile(
id=uuid.uuid4(),
user_id=user_id,
tax_year=tax_year,
tax_code=tax_code,
employer_name_enc=employer_enc,
is_cumulative=is_cumulative,
created_at=now,
updated_at=now,
)
db.add(profile)
action = "tax_profile_create"
else:
profile.tax_code = tax_code
profile.employer_name_enc = employer_enc
profile.is_cumulative = is_cumulative
profile.updated_at = now
action = "tax_profile_update"
await db.flush()
await write_audit(db, user_id=user_id, action=action,
resource_type="tax_profile", resource_id=profile.id)
return profile
def _profile_to_response(profile) -> dict:
from app.db.models.tax import TaxProfile
employer = None
if profile.employer_name_enc:
try:
employer = decrypt_field(profile.employer_name_enc)
except Exception:
employer = None
return {
"id": str(profile.id),
"tax_year": profile.tax_year,
"tax_code": profile.tax_code,
"employer_name": employer,
"is_cumulative": profile.is_cumulative,
"created_at": profile.created_at.isoformat(),
"updated_at": profile.updated_at.isoformat(),
}
# ---------------------------------------------------------------------------
# Payslip CRUD
# ---------------------------------------------------------------------------
async def list_payslips(db: AsyncSession, user_id: uuid.UUID, tax_year: int) -> list:
from app.db.models.tax import Payslip, TaxProfile
result = await db.execute(
select(Payslip)
.join(TaxProfile, Payslip.tax_profile_id == TaxProfile.id)
.where(
Payslip.user_id == user_id,
TaxProfile.tax_year == tax_year,
)
.order_by(Payslip.period_year, Payslip.period_month.nulls_last())
)
return list(result.scalars())
async def create_payslip(
db: AsyncSession,
user_id: uuid.UUID,
tax_year: int,
period_month: int | None,
period_year: int,
gross_pay: Decimal,
income_tax_withheld: Decimal,
ni_withheld: Decimal,
net_pay: Decimal,
is_p60: bool = False,
notes: str | None = None,
):
from app.db.models.tax import Payslip
from app.core.audit import write_audit
profile = await get_tax_profile(db, user_id, tax_year)
if profile is None:
raise ValueError(f"No tax profile for year {tax_year}. Create a profile first.")
now = datetime.now(timezone.utc)
notes_enc = encrypt_field(notes) if notes else None
payslip = Payslip(
id=uuid.uuid4(),
user_id=user_id,
tax_profile_id=profile.id,
period_month=period_month,
period_year=period_year,
gross_pay=gross_pay,
income_tax_withheld=income_tax_withheld,
ni_withheld=ni_withheld,
net_pay=net_pay,
is_p60=is_p60,
notes_enc=notes_enc,
created_at=now,
)
db.add(payslip)
await db.flush()
await write_audit(db, user_id=user_id, action="payslip_create",
resource_type="payslip", resource_id=payslip.id)
return payslip
async def update_payslip(
db: AsyncSession,
user_id: uuid.UUID,
payslip_id: uuid.UUID,
**kwargs,
):
from app.db.models.tax import Payslip
from app.core.audit import write_audit
result = await db.execute(
select(Payslip).where(Payslip.id == payslip_id, Payslip.user_id == user_id)
)
payslip = result.scalar_one_or_none()
if payslip is None:
raise ValueError("Payslip not found")
for field, value in kwargs.items():
if field == "notes":
payslip.notes_enc = encrypt_field(value) if value else None
else:
setattr(payslip, field, value)
await db.flush()
await write_audit(db, user_id=user_id, action="payslip_update",
resource_type="payslip", resource_id=payslip.id)
return payslip
async def delete_payslip(db: AsyncSession, user_id: uuid.UUID, payslip_id: uuid.UUID) -> None:
from app.db.models.tax import Payslip
from app.core.audit import write_audit
result = await db.execute(
select(Payslip).where(Payslip.id == payslip_id, Payslip.user_id == user_id)
)
payslip = result.scalar_one_or_none()
if payslip is None:
raise ValueError("Payslip not found")
await write_audit(db, user_id=user_id, action="payslip_delete",
resource_type="payslip", resource_id=payslip_id)
await db.delete(payslip)
await db.flush()
async def replace_with_p60(
db: AsyncSession,
user_id: uuid.UUID,
tax_year: int,
gross_pay: Decimal,
income_tax_withheld: Decimal,
ni_withheld: Decimal,
net_pay: Decimal,
) -> None:
"""Delete all existing payslips for the tax year and replace with a single P60."""
from app.db.models.tax import Payslip, TaxProfile
from app.core.audit import write_audit
profile = await get_tax_profile(db, user_id, tax_year)
if profile is None:
raise ValueError(f"No tax profile for year {tax_year}. Create a profile first.")
await db.execute(
delete(Payslip).where(
Payslip.user_id == user_id,
Payslip.tax_profile_id == profile.id,
)
)
now = datetime.now(timezone.utc)
p60 = Payslip(
id=uuid.uuid4(),
user_id=user_id,
tax_profile_id=profile.id,
period_month=None,
period_year=tax_year - 1, # P60 covers year ending 5 Apr tax_year; the employer year starts the prior calendar year
gross_pay=gross_pay,
income_tax_withheld=income_tax_withheld,
ni_withheld=ni_withheld,
net_pay=net_pay,
is_p60=True,
notes_enc=None,
created_at=now,
)
db.add(p60)
await db.flush()
await write_audit(db, user_id=user_id, action="payslip_p60_replace",
resource_type="payslip", resource_id=p60.id)
def _payslip_to_response(payslip) -> dict:
notes = None
if payslip.notes_enc:
try:
notes = decrypt_field(payslip.notes_enc)
except Exception:
notes = None
return {
"id": str(payslip.id),
"tax_profile_id": str(payslip.tax_profile_id),
"period_month": payslip.period_month,
"period_year": payslip.period_year,
"gross_pay": str(payslip.gross_pay),
"income_tax_withheld": str(payslip.income_tax_withheld),
"ni_withheld": str(payslip.ni_withheld),
"net_pay": str(payslip.net_pay),
"is_p60": payslip.is_p60,
"notes": notes,
"created_at": payslip.created_at.isoformat(),
}
# ---------------------------------------------------------------------------
# Manual CGT disposal CRUD
# ---------------------------------------------------------------------------
async def list_manual_disposals(
db: AsyncSession, user_id: uuid.UUID, tax_year: int
) -> list:
from app.db.models.tax import ManualCGTDisposal
result = await db.execute(
select(ManualCGTDisposal).where(
ManualCGTDisposal.user_id == user_id,
ManualCGTDisposal.tax_year == tax_year,
).order_by(ManualCGTDisposal.disposal_date)
)
return list(result.scalars())
async def create_manual_disposal(
db: AsyncSession,
user_id: uuid.UUID,
tax_year: int,
disposal_date: date,
asset_description: str,
proceeds: Decimal,
cost_basis: Decimal,
notes: str | None = None,
):
from app.db.models.tax import ManualCGTDisposal
from app.core.audit import write_audit
now = datetime.now(timezone.utc)
disposal = ManualCGTDisposal(
id=uuid.uuid4(),
user_id=user_id,
tax_year=tax_year,
disposal_date=disposal_date,
asset_description_enc=encrypt_field(asset_description),
proceeds=proceeds,
cost_basis=cost_basis,
notes_enc=encrypt_field(notes) if notes else None,
created_at=now,
)
db.add(disposal)
await db.flush()
await write_audit(db, user_id=user_id, action="cgt_disposal_create",
resource_type="manual_cgt_disposal", resource_id=disposal.id)
return disposal
async def update_manual_disposal(
db: AsyncSession,
user_id: uuid.UUID,
disposal_id: uuid.UUID,
disposal_date: date,
asset_description: str,
proceeds: Decimal,
cost_basis: Decimal,
notes: str | None = None,
):
from app.db.models.tax import ManualCGTDisposal
from app.core.audit import write_audit
result = await db.execute(
select(ManualCGTDisposal).where(
ManualCGTDisposal.id == disposal_id,
ManualCGTDisposal.user_id == user_id,
)
)
disposal = result.scalar_one_or_none()
if disposal is None:
raise ValueError("Disposal not found")
disposal.disposal_date = disposal_date
disposal.asset_description_enc = encrypt_field(asset_description)
disposal.proceeds = proceeds
disposal.cost_basis = cost_basis
disposal.notes_enc = encrypt_field(notes) if notes else None
await db.flush()
await write_audit(db, user_id=user_id, action="cgt_disposal_update",
resource_type="manual_cgt_disposal", resource_id=disposal.id)
return disposal
async def delete_manual_disposal(
db: AsyncSession, user_id: uuid.UUID, disposal_id: uuid.UUID
) -> None:
from app.db.models.tax import ManualCGTDisposal
from app.core.audit import write_audit
result = await db.execute(
select(ManualCGTDisposal).where(
ManualCGTDisposal.id == disposal_id,
ManualCGTDisposal.user_id == user_id,
)
)
disposal = result.scalar_one_or_none()
if disposal is None:
raise ValueError("Disposal not found")
await write_audit(db, user_id=user_id, action="cgt_disposal_delete",
resource_type="manual_cgt_disposal", resource_id=disposal_id)
await db.delete(disposal)
await db.flush()
def _disposal_to_response(disposal) -> dict:
asset_desc = ""
try:
asset_desc = decrypt_field(disposal.asset_description_enc)
except Exception:
pass
notes = None
if disposal.notes_enc:
try:
notes = decrypt_field(disposal.notes_enc)
except Exception:
pass
gain_loss = disposal.proceeds - disposal.cost_basis
return {
"id": str(disposal.id),
"tax_year": disposal.tax_year,
"disposal_date": disposal.disposal_date.isoformat(),
"asset_description": asset_desc,
"proceeds": str(disposal.proceeds),
"cost_basis": str(disposal.cost_basis),
"gain_loss": str(gain_loss),
"notes": notes,
"created_at": disposal.created_at.isoformat(),
}
# ---------------------------------------------------------------------------
# Tax rate config CRUD
# ---------------------------------------------------------------------------
async def list_configured_years(db: AsyncSession, user_id: uuid.UUID) -> list[int]:
from app.db.models.tax import TaxRateConfig
result = await db.execute(
select(TaxRateConfig.tax_year)
.where(TaxRateConfig.user_id == user_id)
.distinct()
.order_by(TaxRateConfig.tax_year)
)
return [row[0] for row in result]
async def get_rate_config(
db: AsyncSession, user_id: uuid.UUID, tax_year: int
) -> dict:
from app.db.models.tax import TaxRateConfig
result = await db.execute(
select(TaxRateConfig).where(
TaxRateConfig.user_id == user_id,
TaxRateConfig.tax_year == tax_year,
)
)
rows = list(result.scalars())
if not rows:
raise ValueError(f"No rate config for year {tax_year}")
return {
"tax_year": tax_year,
"rates": {row.rate_type: row.config for row in rows},
"updated_at": max(row.updated_at for row in rows).isoformat(),
}
async def upsert_rate_config(
db: AsyncSession,
user_id: uuid.UUID,
tax_year: int,
rates: dict,
) -> dict:
from app.db.models.tax import TaxRateConfig
from app.core.audit import write_audit
now = datetime.now(timezone.utc)
for rate_type, config in rates.items():
result = await db.execute(
select(TaxRateConfig).where(
TaxRateConfig.user_id == user_id,
TaxRateConfig.tax_year == tax_year,
TaxRateConfig.rate_type == rate_type,
)
)
row = result.scalar_one_or_none()
if row is None:
db.add(TaxRateConfig(
id=uuid.uuid4(),
user_id=user_id,
tax_year=tax_year,
rate_type=rate_type,
config=config,
updated_at=now,
))
else:
row.config = config
row.updated_at = now
await db.flush()
await write_audit(db, user_id=user_id, action="tax_rate_config_update",
resource_type="tax_rate_config",
resource_id=None,
metadata={"tax_year": tax_year})
return await get_rate_config(db, user_id, tax_year)
# ---------------------------------------------------------------------------
# Report builder
# ---------------------------------------------------------------------------
async def build_tax_report(
db: AsyncSession, user_id: uuid.UUID, tax_year: int
) -> dict[str, Any]:
"""Build the full tax report for a given year.
Steps:
1. Load rates
2. Load profile + payslip totals
3. Load investment sell disposals within the tax year
4. Load investment dividend transactions within the tax year
5. Load manual CGT disposals
6. Run calculations: income tax NI CGT dividend tax
7. Return full report dict
"""
from app.db.models.tax import ManualCGTDisposal, Payslip, TaxProfile
from app.db.models.investment_transaction import InvestmentTransaction
from app.db.models.investment_holding import InvestmentHolding
from app.db.models.asset import Asset
rates = await load_rates(db, user_id, tax_year)
start_date, end_date = tax_year_date_range(tax_year)
# ---- Profile ----
profile = await get_tax_profile(db, user_id, tax_year)
tax_code = profile.tax_code if profile else "1257L"
profile_data = _profile_to_response(profile) if profile else None
# ---- Payslip totals ----
payslips = await list_payslips(db, user_id, tax_year)
gross_income = sum((Decimal(str(p.gross_pay)) for p in payslips), Decimal("0"))
income_tax_withheld = sum((Decimal(str(p.income_tax_withheld)) for p in payslips), Decimal("0"))
ni_withheld = sum((Decimal(str(p.ni_withheld)) for p in payslips), Decimal("0"))
payslip_rows = [_payslip_to_response(p) for p in payslips]
# ---- Investment sell disposals ----
inv_disposals_result = await db.execute(
select(InvestmentTransaction, InvestmentHolding, Asset)
.join(InvestmentHolding, InvestmentTransaction.holding_id == InvestmentHolding.id)
.join(Asset, InvestmentHolding.asset_id == Asset.id)
.where(
InvestmentTransaction.user_id == user_id,
InvestmentTransaction.type == "sell",
InvestmentTransaction.date >= start_date,
InvestmentTransaction.date <= end_date,
)
.order_by(InvestmentTransaction.date)
)
inv_disposal_rows = []
total_inv_gain = Decimal("0")
for inv_txn, holding, asset in inv_disposals_result:
proceeds = Decimal(str(inv_txn.total_amount))
cost = Decimal(str(inv_txn.quantity)) * Decimal(str(holding.avg_cost_basis))
gain = proceeds - cost - Decimal(str(inv_txn.fees))
total_inv_gain += gain
inv_disposal_rows.append({
"date": inv_txn.date.isoformat(),
"asset": asset.name,
"symbol": asset.symbol,
"quantity": str(inv_txn.quantity),
"proceeds": str(proceeds),
"cost_basis": str(cost),
"fees": str(inv_txn.fees),
"gain_loss": str(gain),
})
# ---- Manual CGT disposals ----
manual_disposals = await list_manual_disposals(db, user_id, tax_year)
manual_disposal_rows = [_disposal_to_response(d) for d in manual_disposals]
total_manual_gain = sum(
(Decimal(str(d.proceeds)) - Decimal(str(d.cost_basis)) for d in manual_disposals),
Decimal("0"),
)
total_cgt_gain = total_inv_gain + total_manual_gain
# ---- Investment dividends ----
div_result = await db.execute(
select(InvestmentTransaction, InvestmentHolding, Asset)
.join(InvestmentHolding, InvestmentTransaction.holding_id == InvestmentHolding.id)
.join(Asset, InvestmentHolding.asset_id == Asset.id)
.where(
InvestmentTransaction.user_id == user_id,
InvestmentTransaction.type == "dividend",
InvestmentTransaction.date >= start_date,
InvestmentTransaction.date <= end_date,
)
.order_by(InvestmentTransaction.date)
)
dividend_rows = []
total_dividends = Decimal("0")
for inv_txn, holding, asset in div_result:
amount = Decimal(str(inv_txn.total_amount))
total_dividends += amount
dividend_rows.append({
"date": inv_txn.date.isoformat(),
"asset": asset.name,
"symbol": asset.symbol,
"amount": str(amount),
})
# ---- Calculations ----
income_tax_result = calculate_income_tax(gross_income, tax_code, rates)
ni_result = calculate_ni(gross_income, rates)
cgt_result = calculate_cgt(
total_cgt_gain,
income_tax_result["remaining_basic_rate_band"],
rates,
)
dividend_result = calculate_dividend_tax(
total_dividends,
income_tax_result["remaining_basic_rate_band"],
rates,
)
# ---- Totals ----
total_liability = (
income_tax_result["liability"]
+ ni_result["liability"]
+ cgt_result["liability"]
+ dividend_result["liability"]
)
total_withheld = income_tax_withheld + ni_withheld
net_owed = total_liability - total_withheld
return {
"tax_year": tax_year,
"tax_year_display": f"{tax_year - 1}/{str(tax_year)[2:]}",
"profile": profile_data,
"income": {
"gross_income": str(gross_income),
"income_tax_withheld": str(income_tax_withheld),
"ni_withheld": str(ni_withheld),
"payslips": payslip_rows,
},
"income_tax": {
**{k: str(v) if isinstance(v, Decimal) else v
for k, v in income_tax_result.items()
if k != "remaining_basic_rate_band"},
"withheld": str(income_tax_withheld),
"owed": str(income_tax_result["liability"] - income_tax_withheld),
},
"ni": {
**{k: str(v) if isinstance(v, Decimal) else v
for k, v in ni_result.items()},
"withheld": str(ni_withheld),
"owed": str(ni_result["liability"] - ni_withheld),
},
"cgt": {
**{k: str(v) if isinstance(v, Decimal) else v
for k, v in cgt_result.items()},
"investment_disposals": inv_disposal_rows,
"manual_disposals": manual_disposal_rows,
"total_gain": str(total_cgt_gain),
},
"dividends": {
**{k: str(v) if isinstance(v, Decimal) else v
for k, v in dividend_result.items()},
"dividend_transactions": dividend_rows,
},
"summary": {
"total_liability": str(total_liability),
"total_withheld": str(total_withheld),
"net_owed": str(net_owed),
"overpaid": net_owed < 0,
},
}

View file

@ -47,6 +47,7 @@ def _to_response(t: Transaction) -> dict:
"notes": _dec(t.notes_enc),
"tags": t.tags or [],
"is_recurring": t.is_recurring,
"recurring_rule": t.recurring_rule,
"attachment_refs": t.attachment_refs or [],
"created_at": t.created_at,
"updated_at": t.updated_at,
@ -221,6 +222,10 @@ async def update_transaction(
txn.notes_enc = _enc(data.notes)
if data.tags is not None:
txn.tags = data.tags
if data.is_recurring is not None:
txn.is_recurring = data.is_recurring
if data.recurring_rule is not None:
txn.recurring_rule = data.recurring_rule
txn.updated_at = now
await db.flush()
@ -307,5 +312,7 @@ async def import_csv(
await db.flush()
if imported > 0:
await recalculate_balance(db, account_id)
from app.services.recurring_service import detect_recurring
await detect_recurring(db, user_id)
return {"imported": imported, "skipped": skipped}