Add recurring transaction detection, subscriptions page, and UK tax reporting

- Recurring service: auto-detects direct debits/subscriptions from CSV imports
  using frequency analysis; manual toggle in transaction detail drawer
- Subscriptions page (/subscriptions): groups recurring payments with monthly
  cost equivalents, next-payment badges, and re-scan trigger
- UK Tax page (/tax): payslips/P60 entry, income tax + NI + CGT + dividend tax
  calculations, configurable rate tables per tax year (pre-seeded 2024/25 and
  2025/26), editable in-app so Budget changes need no rebuild
- Migration 0006: tax_rate_configs, tax_profiles, payslips, manual_cgt_disposals
  with RLS; seeds 2025/2026 rate configs for existing users
- Chart tooltip fix: all Recharts tooltips now use TOOLTIP_STYLE constant so
  they render correctly across all dark/light themes

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
megaproxy 2026-04-23 21:40:02 +00:00
parent 0b326cbd87
commit afb5e99bb2
48 changed files with 6238 additions and 39 deletions

View file

@ -0,0 +1,186 @@
"""add tax tables
Revision ID: 0006
Revises: 0005
Create Date: 2026-04-23
"""
import uuid
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
revision = "0006"
down_revision = "0005"
branch_labels = None
depends_on = None
# ---------------------------------------------------------------------------
# Seed data for 2025 and 2026
# ---------------------------------------------------------------------------
_INCOME_TAX_BANDS_2025 = {
"bands": [
{"from": 0, "to": 12570, "rate": 0.00},
{"from": 12570, "to": 50270, "rate": 0.20},
{"from": 50270, "to": 125140, "rate": 0.40},
{"from": 125140, "to": None, "rate": 0.45},
]
}
_NI_BANDS_2025 = {
"bands": [
{"from": 0, "to": 12570, "rate": 0.00},
{"from": 12570, "to": 50270, "rate": 0.08},
{"from": 50270, "to": None, "rate": 0.02},
]
}
_CGT_2025 = {"exempt": 3000, "basic_rate": 0.18, "higher_rate": 0.24}
_DIVIDEND_2025 = {
"allowance": 500,
"basic_rate": 0.0875,
"higher_rate": 0.3375,
"additional_rate": 0.3935,
}
# 2026 thresholds remain frozen; rates unchanged from 2025
_SEED = {
2025: {
"income_tax": _INCOME_TAX_BANDS_2025,
"ni": _NI_BANDS_2025,
"cgt": _CGT_2025,
"dividend": _DIVIDEND_2025,
},
2026: {
"income_tax": _INCOME_TAX_BANDS_2025,
"ni": _NI_BANDS_2025,
"cgt": _CGT_2025,
"dividend": _DIVIDEND_2025,
},
}
def upgrade() -> None:
# ------------------------------------------------------------------
# tax_rate_configs
# ------------------------------------------------------------------
op.create_table(
"tax_rate_configs",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column("user_id", postgresql.UUID(as_uuid=True),
sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
sa.Column("tax_year", sa.Integer, nullable=False),
sa.Column("rate_type", sa.String(30), nullable=False),
sa.Column("config", postgresql.JSONB, nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
)
op.create_unique_constraint(
"uq_tax_rate_configs_user_year_type",
"tax_rate_configs",
["user_id", "tax_year", "rate_type"],
)
op.create_index("ix_tax_rate_configs_user_id", "tax_rate_configs", ["user_id"])
# ------------------------------------------------------------------
# tax_profiles
# ------------------------------------------------------------------
op.create_table(
"tax_profiles",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column("user_id", postgresql.UUID(as_uuid=True),
sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
sa.Column("tax_year", sa.Integer, nullable=False),
sa.Column("employer_name_enc", sa.LargeBinary, nullable=True),
sa.Column("tax_code", sa.String(20), nullable=False, server_default="1257L"),
sa.Column("is_cumulative", sa.Boolean, nullable=False, server_default="true"),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
)
op.create_unique_constraint(
"uq_tax_profiles_user_year",
"tax_profiles",
["user_id", "tax_year"],
)
op.create_index("ix_tax_profiles_user_id", "tax_profiles", ["user_id"])
# ------------------------------------------------------------------
# payslips
# ------------------------------------------------------------------
op.create_table(
"payslips",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column("user_id", postgresql.UUID(as_uuid=True),
sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
sa.Column("tax_profile_id", postgresql.UUID(as_uuid=True),
sa.ForeignKey("tax_profiles.id", ondelete="CASCADE"), nullable=False),
sa.Column("period_month", sa.SmallInteger, nullable=True),
sa.Column("period_year", sa.SmallInteger, nullable=False),
sa.Column("gross_pay", sa.Numeric(14, 2), nullable=False),
sa.Column("income_tax_withheld", sa.Numeric(14, 2), nullable=False),
sa.Column("ni_withheld", sa.Numeric(14, 2), nullable=False),
sa.Column("net_pay", sa.Numeric(14, 2), nullable=False),
sa.Column("is_p60", sa.Boolean, nullable=False, server_default="false"),
sa.Column("notes_enc", sa.LargeBinary, nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
)
op.create_index("ix_payslips_user_id", "payslips", ["user_id"])
op.create_index("ix_payslips_tax_profile_id", "payslips", ["tax_profile_id"])
# ------------------------------------------------------------------
# manual_cgt_disposals
# ------------------------------------------------------------------
op.create_table(
"manual_cgt_disposals",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column("user_id", postgresql.UUID(as_uuid=True),
sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
sa.Column("tax_year", sa.Integer, nullable=False),
sa.Column("disposal_date", sa.Date, nullable=False),
sa.Column("asset_description_enc", sa.LargeBinary, nullable=False),
sa.Column("proceeds", sa.Numeric(14, 2), nullable=False),
sa.Column("cost_basis", sa.Numeric(14, 2), nullable=False),
sa.Column("notes_enc", sa.LargeBinary, nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
)
op.create_index("ix_manual_cgt_disposals_user_id", "manual_cgt_disposals", ["user_id"])
# ------------------------------------------------------------------
# RLS
# ------------------------------------------------------------------
for table in ["tax_rate_configs", "tax_profiles", "payslips", "manual_cgt_disposals"]:
op.execute(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY")
op.execute(f"""
CREATE POLICY {table}_user_isolation ON {table}
USING (user_id = current_app_user_id())
""")
# ------------------------------------------------------------------
# Seed 2025 + 2026 rate configs for all existing users
# ------------------------------------------------------------------
import json
from datetime import datetime, timezone
now = datetime.now(timezone.utc).isoformat()
for tax_year, rate_types in _SEED.items():
for rate_type, config in rate_types.items():
op.execute(sa.text("""
INSERT INTO tax_rate_configs (id, user_id, tax_year, rate_type, config, updated_at)
SELECT gen_random_uuid(), id, :tax_year, :rate_type, CAST(:config AS jsonb), CAST(:updated_at AS timestamptz)
FROM users
WHERE deleted_at IS NULL
ON CONFLICT (user_id, tax_year, rate_type) DO NOTHING
""").bindparams(
tax_year=tax_year,
rate_type=rate_type,
config=json.dumps(config),
updated_at=now,
))
def downgrade() -> None:
for table in ["tax_rate_configs", "tax_profiles", "payslips", "manual_cgt_disposals"]:
op.execute(f"DROP POLICY IF EXISTS {table}_user_isolation ON {table}")
op.execute(f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY")
op.drop_table("manual_cgt_disposals")
op.drop_table("payslips")
op.drop_table("tax_profiles")
op.drop_table("tax_rate_configs")

View file

@ -1,6 +1,6 @@
from fastapi import APIRouter
from app.api.v1 import auth, users, accounts, categories, transactions, budgets, reports, investments, predictions, admin, settings
from app.api.v1 import auth, users, accounts, categories, transactions, budgets, reports, investments, predictions, admin, settings, subscriptions, tax
router = APIRouter()
router.include_router(auth.router, prefix="/auth", tags=["auth"])
@ -14,3 +14,5 @@ router.include_router(investments.router)
router.include_router(predictions.router)
router.include_router(admin.router)
router.include_router(settings.router)
router.include_router(subscriptions.router)
router.include_router(tax.router)

View file

@ -0,0 +1,123 @@
from __future__ import annotations
from collections import defaultdict
from decimal import Decimal
from fastapi import APIRouter, Depends
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.security import decrypt_field
from app.db.models.account import Account
from app.db.models.transaction import Transaction
from app.dependencies import get_current_user, get_db
router = APIRouter(prefix="/subscriptions", tags=["subscriptions"])
_MONTHLY_FACTORS = {
"weekly": Decimal("52") / Decimal("12"),
"fortnightly": Decimal("26") / Decimal("12"),
"monthly": Decimal("1"),
"quarterly": Decimal("1") / Decimal("3"),
"yearly": Decimal("1") / Decimal("12"),
}
@router.get("")
async def get_subscriptions(
db: AsyncSession = Depends(get_db),
user=Depends(get_current_user),
):
"""Return all detected recurring transactions grouped as subscriptions."""
txn_result = await db.execute(
select(Transaction).where(
Transaction.user_id == user.id,
Transaction.is_recurring == True,
Transaction.deleted_at.is_(None),
)
)
transactions = txn_result.scalars().all()
# Load accounts for name lookup
acc_result = await db.execute(
select(Account).where(
Account.user_id == user.id,
Account.deleted_at.is_(None),
)
)
account_map = {a.id: a for a in acc_result.scalars().all()}
# Group by (normalised frequency+amount key from recurring_rule)
# Use (frequency, typical_amount, normalised_name) as the grouping key
# so manually-set entries with no rule still appear individually
from app.services.recurring_service import normalise_description
# Group: key → list of transactions
groups: dict[str, list[Transaction]] = defaultdict(list)
for txn in transactions:
rule = txn.recurring_rule or {}
freq = rule.get("frequency", "unknown")
amt = rule.get("typical_amount", float(txn.amount))
try:
desc = decrypt_field(txn.description_enc) or ""
except Exception:
desc = ""
norm = normalise_description(desc)
key = f"{norm}|{amt}|{freq}"
groups[key].append(txn)
subscriptions = []
total_monthly = Decimal("0")
for key, txns in groups.items():
# Use the transaction with the most recent date as the representative
txns_sorted = sorted(txns, key=lambda t: t.date, reverse=True)
latest = txns_sorted[0]
rule = latest.recurring_rule or {}
freq = rule.get("frequency", "unknown")
amount = Decimal(str(rule.get("typical_amount", float(latest.amount))))
next_expected = rule.get("next_expected")
last_paid = rule.get("last_paid") or str(latest.date)
confidence = rule.get("confidence", 1.0)
manually_set = rule.get("manually_set", False)
try:
desc = decrypt_field(latest.description_enc) or ""
except Exception:
desc = ""
account = account_map.get(latest.account_id)
try:
account_name = decrypt_field(account.name_enc) if account else None
except Exception:
account_name = None
factor = _MONTHLY_FACTORS.get(freq, Decimal("1"))
monthly_equiv = abs(amount) * factor
total_monthly += monthly_equiv
subscriptions.append({
"name": desc,
"amount": float(amount),
"frequency": freq,
"next_expected": next_expected,
"last_paid": last_paid,
"account_id": str(latest.account_id),
"account_name": account_name,
"transaction_ids": [str(t.id) for t in txns],
"latest_transaction_id": str(latest.id),
"monthly_equivalent": float(monthly_equiv.quantize(Decimal("0.01"))),
"confidence": confidence,
"manually_set": manually_set,
})
# Sort by next_expected ascending (soonest first), nulls last
subscriptions.sort(key=lambda s: s["next_expected"] or "9999-99-99")
return {
"total_monthly_equivalent": float(total_monthly.quantize(Decimal("0.01"))),
"currency": user.base_currency,
"subscriptions": subscriptions,
}

293
backend/app/api/v1/tax.py Normal file
View file

@ -0,0 +1,293 @@
import uuid
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.dependencies import get_current_user, get_db
from app.db.models.user import User
from app.schemas.tax import (
ManualDisposalCreate,
ManualDisposalResponse,
ManualDisposalUpdate,
P60Entry,
PayslipCreate,
PayslipResponse,
PayslipUpdate,
TaxProfileCreate,
TaxProfileResponse,
TaxRateConfigResponse,
TaxRateConfigUpdate,
TaxReportResponse,
)
from app.services import tax_service
router = APIRouter(tags=["tax"])
# ---------------------------------------------------------------------------
# Rate configs
# ---------------------------------------------------------------------------
@router.get("/tax/rate-configs", response_model=list[int])
async def list_rate_config_years(
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
return await tax_service.list_configured_years(db, current_user.id)
@router.get("/tax/rate-configs/{tax_year}", response_model=TaxRateConfigResponse)
async def get_rate_config(
tax_year: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
try:
return await tax_service.get_rate_config(db, current_user.id, tax_year)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@router.put("/tax/rate-configs/{tax_year}", response_model=TaxRateConfigResponse)
async def upsert_rate_config(
tax_year: int,
data: TaxRateConfigUpdate,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
rates = {k: v for k, v in data.model_dump().items() if v is not None}
if not rates:
raise HTTPException(status_code=422, detail="At least one rate type must be provided")
result = await tax_service.upsert_rate_config(db, current_user.id, tax_year, rates)
await db.commit()
return result
# ---------------------------------------------------------------------------
# Tax profile
# ---------------------------------------------------------------------------
@router.get("/tax/profile/{tax_year}", response_model=TaxProfileResponse)
async def get_tax_profile(
tax_year: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
profile = await tax_service.get_tax_profile(db, current_user.id, tax_year)
if profile is None:
raise HTTPException(status_code=404, detail="No tax profile for this year")
return tax_service._profile_to_response(profile)
@router.put("/tax/profile/{tax_year}", response_model=TaxProfileResponse)
async def upsert_tax_profile(
tax_year: int,
data: TaxProfileCreate,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
profile = await tax_service.upsert_tax_profile(
db,
current_user.id,
tax_year,
tax_code=data.tax_code,
employer_name=data.employer_name,
is_cumulative=data.is_cumulative,
)
await db.commit()
return tax_service._profile_to_response(profile)
# ---------------------------------------------------------------------------
# Payslips
# ---------------------------------------------------------------------------
@router.get("/tax/payslips/{tax_year}", response_model=list[PayslipResponse])
async def list_payslips(
tax_year: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
payslips = await tax_service.list_payslips(db, current_user.id, tax_year)
return [tax_service._payslip_to_response(p) for p in payslips]
@router.post("/tax/payslips/{tax_year}", response_model=PayslipResponse, status_code=status.HTTP_201_CREATED)
async def create_payslip(
tax_year: int,
data: PayslipCreate,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
try:
payslip = await tax_service.create_payslip(
db,
current_user.id,
tax_year,
period_month=data.period_month,
period_year=data.period_year,
gross_pay=data.gross_pay,
income_tax_withheld=data.income_tax_withheld,
ni_withheld=data.ni_withheld,
net_pay=data.net_pay,
notes=data.notes,
)
await db.commit()
return tax_service._payslip_to_response(payslip)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@router.put("/tax/payslips/{payslip_id}", response_model=PayslipResponse)
async def update_payslip(
payslip_id: uuid.UUID,
data: PayslipUpdate,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
try:
updates = {k: v for k, v in data.model_dump().items() if v is not None}
payslip = await tax_service.update_payslip(db, current_user.id, payslip_id, **updates)
await db.commit()
return tax_service._payslip_to_response(payslip)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@router.delete("/tax/payslips/{payslip_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_payslip(
payslip_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
try:
await tax_service.delete_payslip(db, current_user.id, payslip_id)
await db.commit()
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@router.post("/tax/payslips/{tax_year}/p60", status_code=status.HTTP_204_NO_CONTENT)
async def enter_p60(
tax_year: int,
data: P60Entry,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
try:
await tax_service.replace_with_p60(
db,
current_user.id,
tax_year,
gross_pay=data.gross_pay,
income_tax_withheld=data.income_tax_withheld,
ni_withheld=data.ni_withheld,
net_pay=data.net_pay,
)
await db.commit()
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
# ---------------------------------------------------------------------------
# Manual CGT disposals
# ---------------------------------------------------------------------------
@router.get("/tax/cgt-disposals/{tax_year}", response_model=list[ManualDisposalResponse])
async def list_cgt_disposals(
tax_year: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
disposals = await tax_service.list_manual_disposals(db, current_user.id, tax_year)
return [tax_service._disposal_to_response(d) for d in disposals]
@router.post("/tax/cgt-disposals/{tax_year}", response_model=ManualDisposalResponse, status_code=status.HTTP_201_CREATED)
async def create_cgt_disposal(
tax_year: int,
data: ManualDisposalCreate,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
disposal = await tax_service.create_manual_disposal(
db,
current_user.id,
tax_year,
disposal_date=data.disposal_date,
asset_description=data.asset_description,
proceeds=data.proceeds,
cost_basis=data.cost_basis,
notes=data.notes,
)
await db.commit()
return tax_service._disposal_to_response(disposal)
@router.put("/tax/cgt-disposals/{disposal_id}", response_model=ManualDisposalResponse)
async def update_cgt_disposal(
disposal_id: uuid.UUID,
data: ManualDisposalUpdate,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
from sqlalchemy import select
from app.db.models.tax import ManualCGTDisposal
from app.core.security import decrypt_field
result = await db.execute(
select(ManualCGTDisposal).where(
ManualCGTDisposal.id == disposal_id,
ManualCGTDisposal.user_id == current_user.id,
)
)
disposal = result.scalar_one_or_none()
if disposal is None:
raise HTTPException(status_code=404, detail="Disposal not found")
current_desc = decrypt_field(disposal.asset_description_enc) if disposal.asset_description_enc else ""
try:
updated = await tax_service.update_manual_disposal(
db,
current_user.id,
disposal_id,
disposal_date=data.disposal_date or disposal.disposal_date,
asset_description=data.asset_description or current_desc,
proceeds=data.proceeds if data.proceeds is not None else disposal.proceeds,
cost_basis=data.cost_basis if data.cost_basis is not None else disposal.cost_basis,
notes=data.notes,
)
await db.commit()
return tax_service._disposal_to_response(updated)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@router.delete("/tax/cgt-disposals/{disposal_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_cgt_disposal(
disposal_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
try:
await tax_service.delete_manual_disposal(db, current_user.id, disposal_id)
await db.commit()
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
# ---------------------------------------------------------------------------
# Tax report
# ---------------------------------------------------------------------------
@router.get("/tax/report/{tax_year}", response_model=TaxReportResponse)
async def get_tax_report(
tax_year: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
try:
return await tax_service.build_tax_report(db, current_user.id, tax_year)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))

View file

@ -611,6 +611,18 @@ async def import_transactions(
return result
@router.post("/detect-recurring")
async def detect_recurring_endpoint(
db: AsyncSession = Depends(get_db),
user=Depends(get_current_user),
):
"""Manually trigger recurring transaction detection for the current user."""
from app.services.recurring_service import detect_recurring
result = await detect_recurring(db, user.id)
await db.commit()
return result
@router.get("/import/template")
async def import_template():
from fastapi.responses import Response

View file

@ -11,9 +11,11 @@ from app.db.models.investment_transaction import InvestmentTransaction
from app.db.models.currency import Currency, ExchangeRate
from app.db.models.net_worth_snapshot import NetWorthSnapshot
from app.db.models.audit_log import AuditLog
from app.db.models.tax import TaxRateConfig, TaxProfile, Payslip, ManualCGTDisposal
__all__ = [
"User", "Session", "Account", "Category", "Transaction", "Budget",
"Asset", "AssetPrice", "InvestmentHolding", "InvestmentTransaction",
"Currency", "ExchangeRate", "NetWorthSnapshot", "AuditLog",
"TaxRateConfig", "TaxProfile", "Payslip", "ManualCGTDisposal",
]

View file

@ -0,0 +1,74 @@
import uuid
from datetime import date, datetime
from decimal import Decimal
from sqlalchemy import Boolean, Date, DateTime, ForeignKey, Integer, LargeBinary, Numeric, SmallInteger, String, UniqueConstraint
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.db.base import Base
class TaxRateConfig(Base):
__tablename__ = "tax_rate_configs"
__table_args__ = (
UniqueConstraint("user_id", "tax_year", "rate_type", name="uq_tax_rate_configs_user_year_type"),
)
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
tax_year: Mapped[int] = mapped_column(Integer, nullable=False)
rate_type: Mapped[str] = mapped_column(String(30), nullable=False) # income_tax|ni|cgt|dividend
config: Mapped[dict] = mapped_column(JSONB, nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
class TaxProfile(Base):
__tablename__ = "tax_profiles"
__table_args__ = (
UniqueConstraint("user_id", "tax_year", name="uq_tax_profiles_user_year"),
)
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
tax_year: Mapped[int] = mapped_column(Integer, nullable=False)
employer_name_enc: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
tax_code: Mapped[str] = mapped_column(String(20), nullable=False, default="1257L")
is_cumulative: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
payslips: Mapped[list["Payslip"]] = relationship(back_populates="tax_profile", lazy="noload")
class Payslip(Base):
__tablename__ = "payslips"
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
tax_profile_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("tax_profiles.id", ondelete="CASCADE"), nullable=False, index=True)
period_month: Mapped[int | None] = mapped_column(SmallInteger, nullable=True)
period_year: Mapped[int] = mapped_column(SmallInteger, nullable=False)
gross_pay: Mapped[Decimal] = mapped_column(Numeric(14, 2), nullable=False)
income_tax_withheld: Mapped[Decimal] = mapped_column(Numeric(14, 2), nullable=False)
ni_withheld: Mapped[Decimal] = mapped_column(Numeric(14, 2), nullable=False)
net_pay: Mapped[Decimal] = mapped_column(Numeric(14, 2), nullable=False)
is_p60: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
notes_enc: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
tax_profile: Mapped["TaxProfile"] = relationship(back_populates="payslips", lazy="noload")
class ManualCGTDisposal(Base):
__tablename__ = "manual_cgt_disposals"
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
tax_year: Mapped[int] = mapped_column(Integer, nullable=False)
disposal_date: Mapped[date] = mapped_column(Date, nullable=False)
asset_description_enc: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)
proceeds: Mapped[Decimal] = mapped_column(Numeric(14, 2), nullable=False)
cost_basis: Mapped[Decimal] = mapped_column(Numeric(14, 2), nullable=False)
notes_enc: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)

215
backend/app/schemas/tax.py Normal file
View file

@ -0,0 +1,215 @@
import uuid
from datetime import date as DateType, datetime
from decimal import Decimal
from typing import Any
from pydantic import BaseModel, Field
# ---------------------------------------------------------------------------
# Tax rate config
# ---------------------------------------------------------------------------
class TaxRateConfigUpdate(BaseModel):
"""PUT /tax/rate-configs/{tax_year} — pass only the rate types you want to upsert."""
income_tax: dict[str, Any] | None = None
ni: dict[str, Any] | None = None
cgt: dict[str, Any] | None = None
dividend: dict[str, Any] | None = None
class TaxRateConfigResponse(BaseModel):
tax_year: int
rates: dict[str, Any]
updated_at: str
# ---------------------------------------------------------------------------
# Tax profile
# ---------------------------------------------------------------------------
class TaxProfileCreate(BaseModel):
tax_code: str = Field(default="1257L", min_length=1, max_length=20)
employer_name: str | None = Field(default=None, max_length=200)
is_cumulative: bool = True
class TaxProfileResponse(BaseModel):
id: uuid.UUID
tax_year: int
tax_code: str
employer_name: str | None
is_cumulative: bool
created_at: str
updated_at: str
# ---------------------------------------------------------------------------
# Payslips
# ---------------------------------------------------------------------------
class PayslipCreate(BaseModel):
period_month: int | None = Field(default=None, ge=1, le=12)
period_year: int = Field(..., ge=2000, le=2100)
gross_pay: Decimal = Field(..., ge=0)
income_tax_withheld: Decimal = Field(..., ge=0)
ni_withheld: Decimal = Field(..., ge=0)
net_pay: Decimal = Field(..., ge=0)
notes: str | None = None
class PayslipUpdate(BaseModel):
period_month: int | None = Field(default=None, ge=1, le=12)
period_year: int | None = Field(default=None, ge=2000, le=2100)
gross_pay: Decimal | None = Field(default=None, ge=0)
income_tax_withheld: Decimal | None = Field(default=None, ge=0)
ni_withheld: Decimal | None = Field(default=None, ge=0)
net_pay: Decimal | None = Field(default=None, ge=0)
notes: str | None = None
class PayslipResponse(BaseModel):
id: uuid.UUID
tax_profile_id: uuid.UUID
period_month: int | None
period_year: int
gross_pay: str
income_tax_withheld: str
ni_withheld: str
net_pay: str
is_p60: bool
notes: str | None
created_at: str
class P60Entry(BaseModel):
gross_pay: Decimal = Field(..., ge=0)
income_tax_withheld: Decimal = Field(..., ge=0)
ni_withheld: Decimal = Field(..., ge=0)
net_pay: Decimal = Field(..., ge=0)
# ---------------------------------------------------------------------------
# Manual CGT disposals
# ---------------------------------------------------------------------------
class ManualDisposalCreate(BaseModel):
disposal_date: DateType
asset_description: str = Field(..., min_length=1, max_length=500)
proceeds: Decimal = Field(..., ge=0)
cost_basis: Decimal = Field(..., ge=0)
notes: str | None = None
class ManualDisposalUpdate(BaseModel):
disposal_date: DateType | None = None
asset_description: str | None = Field(default=None, min_length=1, max_length=500)
proceeds: Decimal | None = Field(default=None, ge=0)
cost_basis: Decimal | None = Field(default=None, ge=0)
notes: str | None = None
class ManualDisposalResponse(BaseModel):
id: uuid.UUID
tax_year: int
disposal_date: str
asset_description: str
proceeds: str
cost_basis: str
gain_loss: str
notes: str | None
created_at: str
# ---------------------------------------------------------------------------
# Tax report (nested)
# ---------------------------------------------------------------------------
class BandBreakdownItem(BaseModel):
rate: float
taxable: float
tax: float
from_: int | None = Field(default=None, alias="from")
to: int | None = None
model_config = {"populate_by_name": True}
class IncomeTaxSummary(BaseModel):
personal_allowance: str
taxable_income: str
liability: str
band_breakdown: list[dict[str, Any]]
withheld: str
owed: str
class NISummary(BaseModel):
liability: str
band_breakdown: list[dict[str, Any]]
withheld: str
owed: str
class InvestmentDisposalItem(BaseModel):
date: str
asset: str
symbol: str
quantity: str
proceeds: str
cost_basis: str
fees: str
gain_loss: str
class CGTSummary(BaseModel):
gross_gain: str
exempt: str
taxable_gain: str
liability: str
band_breakdown: list[dict[str, Any]]
investment_disposals: list[dict[str, Any]]
manual_disposals: list[dict[str, Any]]
total_gain: str
class DividendTransactionItem(BaseModel):
date: str
asset: str
symbol: str
amount: str
class DividendSummary(BaseModel):
gross_dividends: str
allowance: str
taxable_dividends: str
liability: str
band_breakdown: list[dict[str, Any]]
dividend_transactions: list[dict[str, Any]]
class TaxReportSummary(BaseModel):
total_liability: str
total_withheld: str
net_owed: str
overpaid: bool
class IncomeSummary(BaseModel):
gross_income: str
income_tax_withheld: str
ni_withheld: str
payslips: list[dict[str, Any]]
class TaxReportResponse(BaseModel):
tax_year: int
tax_year_display: str
profile: dict[str, Any] | None
income: IncomeSummary
income_tax: IncomeTaxSummary
ni: NISummary
cgt: CGTSummary
dividends: DividendSummary
summary: TaxReportSummary

View file

@ -35,6 +35,8 @@ class TransactionUpdate(BaseModel):
merchant: str | None = None
notes: str | None = None
tags: list[str] | None = None
is_recurring: bool | None = None
recurring_rule: dict | None = None
class TransactionFilter(BaseModel):
@ -71,6 +73,7 @@ class TransactionResponse(BaseModel):
notes: str | None
tags: list[str]
is_recurring: bool
recurring_rule: dict | None = None
attachment_refs: list[dict] = []
created_at: datetime
updated_at: datetime

View file

@ -85,6 +85,10 @@ async def register_user(db: AsyncSession, email: str, password: str, display_nam
)
db.add(user)
await db.flush()
from app.services.tax_service import seed_default_rates
await seed_default_rates(db, user.id)
return user

View file

@ -0,0 +1,194 @@
from __future__ import annotations
import re
import uuid
from collections import defaultdict
from datetime import date, datetime, timezone
from decimal import Decimal
from statistics import mean, stdev
from dateutil.relativedelta import relativedelta
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.security import decrypt_field
from app.db.models.transaction import Transaction
MIN_OCCURRENCES = 2
# (label, min_days, max_days, tolerance_days)
_FREQUENCIES = [
("weekly", 6, 8, 2),
("fortnightly", 13, 15, 3),
("monthly", 26, 35, 5),
("quarterly", 85, 95, 10),
("yearly", 355, 375, 15),
]
_STRIP_PREFIXES = re.compile(
r"^(direct debit|standing order|faster payment|bacs|dd|so)\s+",
re.IGNORECASE,
)
_STRIP_REFS = re.compile(r"\s+\d{5,}$")
_STRIP_DATE_PATTERNS = re.compile(
r"\b\d{1,2}[/-]\d{1,2}([/-]\d{2,4})?\b"
r"|\b\d{1,2}(jan|feb|mar|apr|may|jun|jul|aug|sep|oct|nov|dec)\b",
re.IGNORECASE,
)
_COLLAPSE_SPACES = re.compile(r"\s{2,}")
def normalise_description(raw: str) -> str:
s = raw.lower().strip()
s = _STRIP_PREFIXES.sub("", s)
s = _STRIP_DATE_PATTERNS.sub("", s)
s = _STRIP_REFS.sub("", s)
s = _COLLAPSE_SPACES.sub(" ", s).strip()
return s
def classify_frequency(avg_days: float) -> tuple[str, int] | None:
"""Return (label, expected_days) or None if avg_days matches no known frequency."""
for label, lo, hi, _ in _FREQUENCIES:
if lo <= avg_days <= hi:
return label, round(avg_days)
return None
def _within_tolerance(intervals: list[int], avg: float, frequency: str) -> bool:
tolerance = next(t for label, _, _, t in _FREQUENCIES if label == frequency)
return all(abs(iv - avg) <= tolerance for iv in intervals)
def next_expected_date(last: date, frequency: str) -> date:
delta_map = {
"weekly": relativedelta(weeks=1),
"fortnightly": relativedelta(weeks=2),
"monthly": relativedelta(months=1),
"quarterly": relativedelta(months=3),
"yearly": relativedelta(years=1),
}
delta = delta_map[frequency]
result = last + delta
today = date.today()
while result < today:
result += delta
return result
def _confidence(intervals: list[int], expected: int) -> float:
if len(intervals) < 2:
return 1.0
try:
sd = stdev(intervals)
except Exception:
sd = 0.0
conf = 1.0 - (sd / expected) if expected > 0 else 0.0
return round(max(0.0, min(1.0, conf)), 4)
async def detect_recurring(db: AsyncSession, user_id: uuid.UUID) -> dict:
"""
Scan all transactions for a user, detect recurring patterns, and tag them.
Skips transactions where recurring_rule.manually_set == true.
Returns {"newly_tagged": int, "total_recurring": int}.
"""
result = await db.execute(
select(Transaction).where(
Transaction.user_id == user_id,
Transaction.deleted_at.is_(None),
)
)
transactions = result.scalars().all()
txn_map: dict[uuid.UUID, Transaction] = {t.id: t for t in transactions}
# Group by (normalised_description, exact_amount)
keyed: dict[tuple[str, Decimal], tuple[list[date], list[uuid.UUID]]] = defaultdict(
lambda: ([], [])
)
for txn in transactions:
try:
desc = decrypt_field(txn.description_enc) or ""
except Exception:
desc = ""
norm = normalise_description(desc)
amount = txn.amount.quantize(Decimal("0.01"))
dates_list, ids_list = keyed[(norm, amount)]
dates_list.append(txn.date)
ids_list.append(txn.id)
now = datetime.now(timezone.utc)
newly_tagged = 0
total_recurring = 0
matched_ids: set[uuid.UUID] = set()
for (norm_desc, amount), (dates_list, ids_list) in keyed.items():
if len(dates_list) < MIN_OCCURRENCES:
continue
paired = sorted(zip(dates_list, ids_list), key=lambda x: x[0])
sorted_dates = [p[0] for p in paired]
sorted_ids = [p[1] for p in paired]
intervals = [
(sorted_dates[i + 1] - sorted_dates[i]).days
for i in range(len(sorted_dates) - 1)
]
avg = mean(intervals)
freq_result = classify_frequency(avg)
if freq_result is None:
continue
frequency, expected_days = freq_result
if not _within_tolerance(intervals, avg, frequency):
continue
conf = _confidence(intervals, expected_days)
last_date = sorted_dates[-1]
next_date = next_expected_date(last_date, frequency)
for txn_id in sorted_ids:
txn = txn_map.get(txn_id)
if txn is None:
continue
matched_ids.add(txn_id)
rule = txn.recurring_rule or {}
if rule.get("manually_set"):
if txn.is_recurring:
total_recurring += 1
continue
was_recurring = txn.is_recurring
txn.is_recurring = True
txn.recurring_rule = {
"frequency": frequency,
"typical_amount": float(amount),
"next_expected": next_date.isoformat(),
"last_paid": last_date.isoformat(),
"confidence": conf,
"detected_at": now.isoformat(),
"manually_set": False,
}
txn.updated_at = now
if not was_recurring:
newly_tagged += 1
total_recurring += 1
# Un-tag previously auto-detected transactions whose pattern no longer matches
for txn in transactions:
if not txn.is_recurring:
continue
rule = txn.recurring_rule or {}
if rule.get("manually_set"):
continue
if txn.id not in matched_ids:
txn.is_recurring = False
txn.recurring_rule = None
txn.updated_at = now
await db.flush()
return {"newly_tagged": newly_tagged, "total_recurring": total_recurring}

View file

@ -0,0 +1,301 @@
"""
Pure UK tax calculation functions. Zero external dependencies no DB, no ORM.
Each function receives a pre-loaded `rates` dict so they are fully unit-testable.
Tax year convention: tax_year=N means 6 Apr (N-1) 5 Apr N.
"""
from __future__ import annotations
import re
from datetime import date
from decimal import ROUND_HALF_UP, Decimal
from typing import Any
# ---------------------------------------------------------------------------
# Tax year helpers
# ---------------------------------------------------------------------------
def tax_year_for_date(d: date) -> int:
"""Return tax_year int for a calendar date. tax_year=N = 6 Apr (N-1) → 5 Apr N."""
if (d.month, d.day) >= (4, 6):
return d.year + 1
return d.year
def tax_year_date_range(tax_year: int) -> tuple[date, date]:
"""Return (start_date, end_date) inclusive for the given tax year."""
return date(tax_year - 1, 4, 6), date(tax_year, 4, 5)
# ---------------------------------------------------------------------------
# Tax code parser
# ---------------------------------------------------------------------------
def parse_tax_code(code: str) -> dict[str, Any]:
"""Parse a UK tax code string.
Returns:
allowance annual personal allowance in £ (negative for K codes)
rate_override flat rate (0.01.0) if code fixes a single rate, else None
k_code True if K prefix (negative allowance)
no_tax True if NT code
"""
raw = code.strip().upper()
raw = re.sub(r"[/\s]?(W1|M1)$", "", raw)
if raw == "NT":
return {"allowance": Decimal("0"), "rate_override": Decimal("0"), "k_code": False, "no_tax": True}
if raw == "BR":
return {"allowance": Decimal("0"), "rate_override": Decimal("0.20"), "k_code": False, "no_tax": False}
if raw == "D0":
return {"allowance": Decimal("0"), "rate_override": Decimal("0.40"), "k_code": False, "no_tax": False}
if raw == "D1":
return {"allowance": Decimal("0"), "rate_override": Decimal("0.45"), "k_code": False, "no_tax": False}
if raw == "0T":
return {"allowance": Decimal("0"), "rate_override": None, "k_code": False, "no_tax": False}
k_match = re.fullmatch(r"K(\d+)", raw)
if k_match:
return {"allowance": -Decimal(k_match.group(1)) * 10, "rate_override": None, "k_code": True, "no_tax": False}
std_match = re.fullmatch(r"(\d+)[LMNTY]?", raw)
if std_match:
return {"allowance": Decimal(std_match.group(1)) * 10, "rate_override": None, "k_code": False, "no_tax": False}
# Unknown code — treat as 0T
return {"allowance": Decimal("0"), "rate_override": None, "k_code": False, "no_tax": False}
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _apply_bands(amount: Decimal, bands: list[dict]) -> tuple[Decimal, list[dict]]:
total = Decimal("0")
breakdown = []
for band in bands:
band_from = Decimal(str(band["from"]))
band_to = Decimal(str(band["to"])) if band["to"] is not None else None
rate = Decimal(str(band["rate"]))
if amount <= band_from:
break
upper = min(amount, band_to) if band_to is not None else amount
taxable_in_band = upper - band_from
tax_in_band = (taxable_in_band * rate).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
total += tax_in_band
if taxable_in_band > 0:
breakdown.append({
"from": int(band_from),
"to": int(band_to) if band_to is not None else None,
"rate": float(rate),
"taxable": float(taxable_in_band),
"tax": float(tax_in_band),
})
return total, breakdown
def _personal_allowance_tapered(base_allowance: Decimal, gross_income: Decimal) -> Decimal:
"""Reduce PA by £1 per £2 over £100,000; floor at zero at £125,140."""
taper_threshold = Decimal("100000")
if gross_income <= taper_threshold:
return base_allowance
reduction = ((gross_income - taper_threshold) / 2).quantize(Decimal("1"), rounding=ROUND_HALF_UP)
return max(Decimal("0"), base_allowance - reduction)
# ---------------------------------------------------------------------------
# Core calculation functions
# ---------------------------------------------------------------------------
def calculate_income_tax(
gross_income: Decimal,
tax_code: str,
rates: dict,
) -> dict[str, Any]:
"""Calculate income tax liability and remaining basic-rate band.
Bands are applied to GROSS income. The 0% band threshold is adjusted to match
the actual personal allowance from the tax code (with taper applied if applicable).
K codes add their amount to gross income before applying the standard bands.
Returns:
personal_allowance, taxable_income, liability, band_breakdown,
remaining_basic_rate_band (passed downstream to CGT/dividend calculations)
"""
parsed = parse_tax_code(tax_code)
bands = rates["income_tax"]["bands"]
# These are gross-income thresholds from the band definitions
pa_threshold = Decimal(str(next(b["to"] for b in bands if b["rate"] == 0.00)))
basic_rate_upper = Decimal(str(next(b["to"] for b in bands if b["rate"] == 0.20)))
if parsed["no_tax"]:
return {
"personal_allowance": pa_threshold,
"taxable_income": Decimal("0"),
"liability": Decimal("0"),
"band_breakdown": [],
"remaining_basic_rate_band": max(Decimal("0"), basic_rate_upper - gross_income),
}
if parsed["rate_override"] is not None:
liability = (gross_income * parsed["rate_override"]).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
return {
"personal_allowance": Decimal("0"),
"taxable_income": gross_income,
"liability": liability,
"band_breakdown": [{"rate": float(parsed["rate_override"]), "tax": float(liability)}],
"remaining_basic_rate_band": Decimal("0"),
}
if parsed["k_code"]:
# K codes: the K amount adds to taxable base; standard PA band still applies to effective gross.
# effective_gross is the "notional income" HMRC uses to calculate tax.
k_amount = abs(parsed["allowance"])
effective_gross = gross_income + k_amount
personal_allowance = Decimal("0") # K code replaces any standard PA grant
taxable_income = effective_gross # reported as the effective taxable base
liability, band_breakdown = _apply_bands(effective_gross, bands)
remaining_brb = max(Decimal("0"), basic_rate_upper - effective_gross)
else:
base_pa = parsed["allowance"]
personal_allowance = _personal_allowance_tapered(base_pa, gross_income) if base_pa > 0 else base_pa
# Adjust the 0% band to match the actual personal allowance, then apply to gross income.
adjusted_bands = [
{"from": 0, "to": float(personal_allowance), "rate": 0.00} if b["rate"] == 0.00 else b
for b in bands
]
taxable_income = max(Decimal("0"), gross_income - personal_allowance)
liability, band_breakdown = _apply_bands(gross_income, adjusted_bands)
remaining_brb = max(Decimal("0"), basic_rate_upper - gross_income)
return {
"personal_allowance": personal_allowance,
"taxable_income": taxable_income,
"liability": liability,
"band_breakdown": band_breakdown,
"remaining_basic_rate_band": remaining_brb,
}
def calculate_ni(gross_income: Decimal, rates: dict) -> dict[str, Any]:
"""Calculate primary Class 1 NI liability."""
bands = rates["ni"]["bands"]
liability, band_breakdown = _apply_bands(gross_income, bands)
return {"liability": liability, "band_breakdown": band_breakdown}
def calculate_cgt(
total_gain: Decimal,
remaining_basic_rate_band: Decimal,
rates: dict,
) -> dict[str, Any]:
"""Calculate CGT liability.
Gains within the remaining basic-rate band are taxed at basic_rate;
gains above it at higher_rate. Annual exempt amount applied first.
"""
cgt_rates = rates["cgt"]
exempt = Decimal(str(cgt_rates["exempt"]))
basic_rate = Decimal(str(cgt_rates["basic_rate"]))
higher_rate = Decimal(str(cgt_rates["higher_rate"]))
taxable_gain = max(Decimal("0"), total_gain - exempt)
if taxable_gain == 0:
return {
"gross_gain": total_gain,
"exempt": min(total_gain, exempt) if total_gain > 0 else Decimal("0"),
"taxable_gain": Decimal("0"),
"liability": Decimal("0"),
"band_breakdown": [],
}
basic_portion = min(taxable_gain, remaining_basic_rate_band)
higher_portion = taxable_gain - basic_portion
liability = (
(basic_portion * basic_rate) + (higher_portion * higher_rate)
).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
breakdown = []
if basic_portion > 0:
breakdown.append({"rate": float(basic_rate), "taxable": float(basic_portion),
"tax": float((basic_portion * basic_rate).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP))})
if higher_portion > 0:
breakdown.append({"rate": float(higher_rate), "taxable": float(higher_portion),
"tax": float((higher_portion * higher_rate).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP))})
return {
"gross_gain": total_gain,
"exempt": min(total_gain, exempt),
"taxable_gain": taxable_gain,
"liability": liability,
"band_breakdown": breakdown,
}
def calculate_dividend_tax(
total_dividends: Decimal,
remaining_basic_rate_band: Decimal,
rates: dict,
) -> dict[str, Any]:
"""Calculate dividend tax liability.
Dividend allowance applied first; taxable dividends are then slotted into
the remaining income bands to determine which rate applies.
"""
div_rates = rates["dividend"]
allowance = Decimal(str(div_rates["allowance"]))
basic_rate = Decimal(str(div_rates["basic_rate"]))
higher_rate = Decimal(str(div_rates["higher_rate"]))
additional_rate = Decimal(str(div_rates["additional_rate"]))
taxable_dividends = max(Decimal("0"), total_dividends - allowance)
if taxable_dividends == 0:
return {
"gross_dividends": total_dividends,
"allowance": min(total_dividends, allowance),
"taxable_dividends": Decimal("0"),
"liability": Decimal("0"),
"band_breakdown": [],
}
basic_portion = min(taxable_dividends, remaining_basic_rate_band)
remainder = taxable_dividends - basic_portion
higher_upper = Decimal(str(
next(b["to"] for b in rates["income_tax"]["bands"] if b["rate"] == 0.40)
))
higher_portion = min(remainder, higher_upper)
additional_portion = remainder - higher_portion
liability = (
(basic_portion * basic_rate)
+ (higher_portion * higher_rate)
+ (additional_portion * additional_rate)
).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
breakdown = []
if basic_portion > 0:
breakdown.append({"rate": float(basic_rate), "taxable": float(basic_portion),
"tax": float((basic_portion * basic_rate).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP))})
if higher_portion > 0:
breakdown.append({"rate": float(higher_rate), "taxable": float(higher_portion),
"tax": float((higher_portion * higher_rate).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP))})
if additional_portion > 0:
breakdown.append({"rate": float(additional_rate), "taxable": float(additional_portion),
"tax": float((additional_portion * additional_rate).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP))})
return {
"gross_dividends": total_dividends,
"allowance": min(total_dividends, allowance),
"taxable_dividends": taxable_dividends,
"liability": liability,
"band_breakdown": breakdown,
}

View file

@ -0,0 +1,744 @@
"""
UK tax service layer for MyMidas.
Pure calculation functions live in tax_calculations.py (no DB imports).
This module provides the DB-backed service layer: rate loading, CRUD, report builder.
"""
from __future__ import annotations
import uuid
from datetime import date, datetime, timezone
from decimal import Decimal
from typing import Any
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.security import decrypt_field, encrypt_field
from app.services.tax_calculations import ( # noqa: F401 — re-exported for callers
calculate_cgt,
calculate_dividend_tax,
calculate_income_tax,
calculate_ni,
parse_tax_code,
tax_year_date_range,
tax_year_for_date,
)
# ---------------------------------------------------------------------------
# DB helpers
# ---------------------------------------------------------------------------
async def load_rates(db: AsyncSession, user_id: uuid.UUID, tax_year: int) -> dict:
"""Load and return rate config dict for a given user/year.
Returns {"income_tax": {...}, "ni": {...}, "cgt": {...}, "dividend": {...}}
Raises ValueError if any rate type is missing for the requested year.
"""
from app.db.models.tax import TaxRateConfig
result = await db.execute(
select(TaxRateConfig).where(
TaxRateConfig.user_id == user_id,
TaxRateConfig.tax_year == tax_year,
)
)
rows = list(result.scalars())
if not rows:
raise ValueError(f"No tax rate config found for year {tax_year}. Please configure rates first.")
rates: dict = {}
for row in rows:
rates[row.rate_type] = row.config
required = {"income_tax", "ni", "cgt", "dividend"}
missing = required - rates.keys()
if missing:
raise ValueError(f"Incomplete tax rate config for year {tax_year}: missing {missing}")
return rates
async def seed_default_rates(db: AsyncSession, user_id: uuid.UUID) -> None:
"""Insert default 2025 and 2026 rate configs for a newly registered user."""
from app.db.models.tax import TaxRateConfig
now = datetime.now(timezone.utc)
income_tax_bands = {
"bands": [
{"from": 0, "to": 12570, "rate": 0.00},
{"from": 12570, "to": 50270, "rate": 0.20},
{"from": 50270, "to": 125140, "rate": 0.40},
{"from": 125140, "to": None, "rate": 0.45},
]
}
ni_bands = {
"bands": [
{"from": 0, "to": 12570, "rate": 0.00},
{"from": 12570, "to": 50270, "rate": 0.08},
{"from": 50270, "to": None, "rate": 0.02},
]
}
cgt = {"exempt": 3000, "basic_rate": 0.18, "higher_rate": 0.24}
dividend = {
"allowance": 500,
"basic_rate": 0.0875,
"higher_rate": 0.3375,
"additional_rate": 0.3935,
}
defaults = {
"income_tax": income_tax_bands,
"ni": ni_bands,
"cgt": cgt,
"dividend": dividend,
}
for tax_year in (2025, 2026):
for rate_type, config in defaults.items():
existing = await db.execute(
select(TaxRateConfig).where(
TaxRateConfig.user_id == user_id,
TaxRateConfig.tax_year == tax_year,
TaxRateConfig.rate_type == rate_type,
)
)
if existing.scalar_one_or_none() is None:
db.add(TaxRateConfig(
id=uuid.uuid4(),
user_id=user_id,
tax_year=tax_year,
rate_type=rate_type,
config=config,
updated_at=now,
))
# ---------------------------------------------------------------------------
# Tax profile CRUD
# ---------------------------------------------------------------------------
async def get_tax_profile(
db: AsyncSession, user_id: uuid.UUID, tax_year: int
):
from app.db.models.tax import TaxProfile
result = await db.execute(
select(TaxProfile).where(
TaxProfile.user_id == user_id,
TaxProfile.tax_year == tax_year,
)
)
return result.scalar_one_or_none()
async def upsert_tax_profile(
db: AsyncSession,
user_id: uuid.UUID,
tax_year: int,
tax_code: str,
employer_name: str | None,
is_cumulative: bool,
):
from app.db.models.tax import TaxProfile
from app.core.audit import write_audit
now = datetime.now(timezone.utc)
profile = await get_tax_profile(db, user_id, tax_year)
employer_enc = encrypt_field(employer_name) if employer_name else None
if profile is None:
profile = TaxProfile(
id=uuid.uuid4(),
user_id=user_id,
tax_year=tax_year,
tax_code=tax_code,
employer_name_enc=employer_enc,
is_cumulative=is_cumulative,
created_at=now,
updated_at=now,
)
db.add(profile)
action = "tax_profile_create"
else:
profile.tax_code = tax_code
profile.employer_name_enc = employer_enc
profile.is_cumulative = is_cumulative
profile.updated_at = now
action = "tax_profile_update"
await db.flush()
await write_audit(db, user_id=user_id, action=action,
resource_type="tax_profile", resource_id=profile.id)
return profile
def _profile_to_response(profile) -> dict:
from app.db.models.tax import TaxProfile
employer = None
if profile.employer_name_enc:
try:
employer = decrypt_field(profile.employer_name_enc)
except Exception:
employer = None
return {
"id": str(profile.id),
"tax_year": profile.tax_year,
"tax_code": profile.tax_code,
"employer_name": employer,
"is_cumulative": profile.is_cumulative,
"created_at": profile.created_at.isoformat(),
"updated_at": profile.updated_at.isoformat(),
}
# ---------------------------------------------------------------------------
# Payslip CRUD
# ---------------------------------------------------------------------------
async def list_payslips(db: AsyncSession, user_id: uuid.UUID, tax_year: int) -> list:
from app.db.models.tax import Payslip, TaxProfile
result = await db.execute(
select(Payslip)
.join(TaxProfile, Payslip.tax_profile_id == TaxProfile.id)
.where(
Payslip.user_id == user_id,
TaxProfile.tax_year == tax_year,
)
.order_by(Payslip.period_year, Payslip.period_month.nulls_last())
)
return list(result.scalars())
async def create_payslip(
db: AsyncSession,
user_id: uuid.UUID,
tax_year: int,
period_month: int | None,
period_year: int,
gross_pay: Decimal,
income_tax_withheld: Decimal,
ni_withheld: Decimal,
net_pay: Decimal,
is_p60: bool = False,
notes: str | None = None,
):
from app.db.models.tax import Payslip
from app.core.audit import write_audit
profile = await get_tax_profile(db, user_id, tax_year)
if profile is None:
raise ValueError(f"No tax profile for year {tax_year}. Create a profile first.")
now = datetime.now(timezone.utc)
notes_enc = encrypt_field(notes) if notes else None
payslip = Payslip(
id=uuid.uuid4(),
user_id=user_id,
tax_profile_id=profile.id,
period_month=period_month,
period_year=period_year,
gross_pay=gross_pay,
income_tax_withheld=income_tax_withheld,
ni_withheld=ni_withheld,
net_pay=net_pay,
is_p60=is_p60,
notes_enc=notes_enc,
created_at=now,
)
db.add(payslip)
await db.flush()
await write_audit(db, user_id=user_id, action="payslip_create",
resource_type="payslip", resource_id=payslip.id)
return payslip
async def update_payslip(
db: AsyncSession,
user_id: uuid.UUID,
payslip_id: uuid.UUID,
**kwargs,
):
from app.db.models.tax import Payslip
from app.core.audit import write_audit
result = await db.execute(
select(Payslip).where(Payslip.id == payslip_id, Payslip.user_id == user_id)
)
payslip = result.scalar_one_or_none()
if payslip is None:
raise ValueError("Payslip not found")
for field, value in kwargs.items():
if field == "notes":
payslip.notes_enc = encrypt_field(value) if value else None
else:
setattr(payslip, field, value)
await db.flush()
await write_audit(db, user_id=user_id, action="payslip_update",
resource_type="payslip", resource_id=payslip.id)
return payslip
async def delete_payslip(db: AsyncSession, user_id: uuid.UUID, payslip_id: uuid.UUID) -> None:
from app.db.models.tax import Payslip
from app.core.audit import write_audit
result = await db.execute(
select(Payslip).where(Payslip.id == payslip_id, Payslip.user_id == user_id)
)
payslip = result.scalar_one_or_none()
if payslip is None:
raise ValueError("Payslip not found")
await write_audit(db, user_id=user_id, action="payslip_delete",
resource_type="payslip", resource_id=payslip_id)
await db.delete(payslip)
await db.flush()
async def replace_with_p60(
db: AsyncSession,
user_id: uuid.UUID,
tax_year: int,
gross_pay: Decimal,
income_tax_withheld: Decimal,
ni_withheld: Decimal,
net_pay: Decimal,
) -> None:
"""Delete all existing payslips for the tax year and replace with a single P60."""
from app.db.models.tax import Payslip, TaxProfile
from app.core.audit import write_audit
profile = await get_tax_profile(db, user_id, tax_year)
if profile is None:
raise ValueError(f"No tax profile for year {tax_year}. Create a profile first.")
await db.execute(
delete(Payslip).where(
Payslip.user_id == user_id,
Payslip.tax_profile_id == profile.id,
)
)
now = datetime.now(timezone.utc)
p60 = Payslip(
id=uuid.uuid4(),
user_id=user_id,
tax_profile_id=profile.id,
period_month=None,
period_year=tax_year - 1, # P60 covers year ending 5 Apr tax_year; the employer year starts the prior calendar year
gross_pay=gross_pay,
income_tax_withheld=income_tax_withheld,
ni_withheld=ni_withheld,
net_pay=net_pay,
is_p60=True,
notes_enc=None,
created_at=now,
)
db.add(p60)
await db.flush()
await write_audit(db, user_id=user_id, action="payslip_p60_replace",
resource_type="payslip", resource_id=p60.id)
def _payslip_to_response(payslip) -> dict:
notes = None
if payslip.notes_enc:
try:
notes = decrypt_field(payslip.notes_enc)
except Exception:
notes = None
return {
"id": str(payslip.id),
"tax_profile_id": str(payslip.tax_profile_id),
"period_month": payslip.period_month,
"period_year": payslip.period_year,
"gross_pay": str(payslip.gross_pay),
"income_tax_withheld": str(payslip.income_tax_withheld),
"ni_withheld": str(payslip.ni_withheld),
"net_pay": str(payslip.net_pay),
"is_p60": payslip.is_p60,
"notes": notes,
"created_at": payslip.created_at.isoformat(),
}
# ---------------------------------------------------------------------------
# Manual CGT disposal CRUD
# ---------------------------------------------------------------------------
async def list_manual_disposals(
db: AsyncSession, user_id: uuid.UUID, tax_year: int
) -> list:
from app.db.models.tax import ManualCGTDisposal
result = await db.execute(
select(ManualCGTDisposal).where(
ManualCGTDisposal.user_id == user_id,
ManualCGTDisposal.tax_year == tax_year,
).order_by(ManualCGTDisposal.disposal_date)
)
return list(result.scalars())
async def create_manual_disposal(
db: AsyncSession,
user_id: uuid.UUID,
tax_year: int,
disposal_date: date,
asset_description: str,
proceeds: Decimal,
cost_basis: Decimal,
notes: str | None = None,
):
from app.db.models.tax import ManualCGTDisposal
from app.core.audit import write_audit
now = datetime.now(timezone.utc)
disposal = ManualCGTDisposal(
id=uuid.uuid4(),
user_id=user_id,
tax_year=tax_year,
disposal_date=disposal_date,
asset_description_enc=encrypt_field(asset_description),
proceeds=proceeds,
cost_basis=cost_basis,
notes_enc=encrypt_field(notes) if notes else None,
created_at=now,
)
db.add(disposal)
await db.flush()
await write_audit(db, user_id=user_id, action="cgt_disposal_create",
resource_type="manual_cgt_disposal", resource_id=disposal.id)
return disposal
async def update_manual_disposal(
db: AsyncSession,
user_id: uuid.UUID,
disposal_id: uuid.UUID,
disposal_date: date,
asset_description: str,
proceeds: Decimal,
cost_basis: Decimal,
notes: str | None = None,
):
from app.db.models.tax import ManualCGTDisposal
from app.core.audit import write_audit
result = await db.execute(
select(ManualCGTDisposal).where(
ManualCGTDisposal.id == disposal_id,
ManualCGTDisposal.user_id == user_id,
)
)
disposal = result.scalar_one_or_none()
if disposal is None:
raise ValueError("Disposal not found")
disposal.disposal_date = disposal_date
disposal.asset_description_enc = encrypt_field(asset_description)
disposal.proceeds = proceeds
disposal.cost_basis = cost_basis
disposal.notes_enc = encrypt_field(notes) if notes else None
await db.flush()
await write_audit(db, user_id=user_id, action="cgt_disposal_update",
resource_type="manual_cgt_disposal", resource_id=disposal.id)
return disposal
async def delete_manual_disposal(
db: AsyncSession, user_id: uuid.UUID, disposal_id: uuid.UUID
) -> None:
from app.db.models.tax import ManualCGTDisposal
from app.core.audit import write_audit
result = await db.execute(
select(ManualCGTDisposal).where(
ManualCGTDisposal.id == disposal_id,
ManualCGTDisposal.user_id == user_id,
)
)
disposal = result.scalar_one_or_none()
if disposal is None:
raise ValueError("Disposal not found")
await write_audit(db, user_id=user_id, action="cgt_disposal_delete",
resource_type="manual_cgt_disposal", resource_id=disposal_id)
await db.delete(disposal)
await db.flush()
def _disposal_to_response(disposal) -> dict:
asset_desc = ""
try:
asset_desc = decrypt_field(disposal.asset_description_enc)
except Exception:
pass
notes = None
if disposal.notes_enc:
try:
notes = decrypt_field(disposal.notes_enc)
except Exception:
pass
gain_loss = disposal.proceeds - disposal.cost_basis
return {
"id": str(disposal.id),
"tax_year": disposal.tax_year,
"disposal_date": disposal.disposal_date.isoformat(),
"asset_description": asset_desc,
"proceeds": str(disposal.proceeds),
"cost_basis": str(disposal.cost_basis),
"gain_loss": str(gain_loss),
"notes": notes,
"created_at": disposal.created_at.isoformat(),
}
# ---------------------------------------------------------------------------
# Tax rate config CRUD
# ---------------------------------------------------------------------------
async def list_configured_years(db: AsyncSession, user_id: uuid.UUID) -> list[int]:
from app.db.models.tax import TaxRateConfig
result = await db.execute(
select(TaxRateConfig.tax_year)
.where(TaxRateConfig.user_id == user_id)
.distinct()
.order_by(TaxRateConfig.tax_year)
)
return [row[0] for row in result]
async def get_rate_config(
db: AsyncSession, user_id: uuid.UUID, tax_year: int
) -> dict:
from app.db.models.tax import TaxRateConfig
result = await db.execute(
select(TaxRateConfig).where(
TaxRateConfig.user_id == user_id,
TaxRateConfig.tax_year == tax_year,
)
)
rows = list(result.scalars())
if not rows:
raise ValueError(f"No rate config for year {tax_year}")
return {
"tax_year": tax_year,
"rates": {row.rate_type: row.config for row in rows},
"updated_at": max(row.updated_at for row in rows).isoformat(),
}
async def upsert_rate_config(
db: AsyncSession,
user_id: uuid.UUID,
tax_year: int,
rates: dict,
) -> dict:
from app.db.models.tax import TaxRateConfig
from app.core.audit import write_audit
now = datetime.now(timezone.utc)
for rate_type, config in rates.items():
result = await db.execute(
select(TaxRateConfig).where(
TaxRateConfig.user_id == user_id,
TaxRateConfig.tax_year == tax_year,
TaxRateConfig.rate_type == rate_type,
)
)
row = result.scalar_one_or_none()
if row is None:
db.add(TaxRateConfig(
id=uuid.uuid4(),
user_id=user_id,
tax_year=tax_year,
rate_type=rate_type,
config=config,
updated_at=now,
))
else:
row.config = config
row.updated_at = now
await db.flush()
await write_audit(db, user_id=user_id, action="tax_rate_config_update",
resource_type="tax_rate_config",
resource_id=None,
metadata={"tax_year": tax_year})
return await get_rate_config(db, user_id, tax_year)
# ---------------------------------------------------------------------------
# Report builder
# ---------------------------------------------------------------------------
async def build_tax_report(
db: AsyncSession, user_id: uuid.UUID, tax_year: int
) -> dict[str, Any]:
"""Build the full tax report for a given year.
Steps:
1. Load rates
2. Load profile + payslip totals
3. Load investment sell disposals within the tax year
4. Load investment dividend transactions within the tax year
5. Load manual CGT disposals
6. Run calculations: income tax NI CGT dividend tax
7. Return full report dict
"""
from app.db.models.tax import ManualCGTDisposal, Payslip, TaxProfile
from app.db.models.investment_transaction import InvestmentTransaction
from app.db.models.investment_holding import InvestmentHolding
from app.db.models.asset import Asset
rates = await load_rates(db, user_id, tax_year)
start_date, end_date = tax_year_date_range(tax_year)
# ---- Profile ----
profile = await get_tax_profile(db, user_id, tax_year)
tax_code = profile.tax_code if profile else "1257L"
profile_data = _profile_to_response(profile) if profile else None
# ---- Payslip totals ----
payslips = await list_payslips(db, user_id, tax_year)
gross_income = sum((Decimal(str(p.gross_pay)) for p in payslips), Decimal("0"))
income_tax_withheld = sum((Decimal(str(p.income_tax_withheld)) for p in payslips), Decimal("0"))
ni_withheld = sum((Decimal(str(p.ni_withheld)) for p in payslips), Decimal("0"))
payslip_rows = [_payslip_to_response(p) for p in payslips]
# ---- Investment sell disposals ----
inv_disposals_result = await db.execute(
select(InvestmentTransaction, InvestmentHolding, Asset)
.join(InvestmentHolding, InvestmentTransaction.holding_id == InvestmentHolding.id)
.join(Asset, InvestmentHolding.asset_id == Asset.id)
.where(
InvestmentTransaction.user_id == user_id,
InvestmentTransaction.type == "sell",
InvestmentTransaction.date >= start_date,
InvestmentTransaction.date <= end_date,
)
.order_by(InvestmentTransaction.date)
)
inv_disposal_rows = []
total_inv_gain = Decimal("0")
for inv_txn, holding, asset in inv_disposals_result:
proceeds = Decimal(str(inv_txn.total_amount))
cost = Decimal(str(inv_txn.quantity)) * Decimal(str(holding.avg_cost_basis))
gain = proceeds - cost - Decimal(str(inv_txn.fees))
total_inv_gain += gain
inv_disposal_rows.append({
"date": inv_txn.date.isoformat(),
"asset": asset.name,
"symbol": asset.symbol,
"quantity": str(inv_txn.quantity),
"proceeds": str(proceeds),
"cost_basis": str(cost),
"fees": str(inv_txn.fees),
"gain_loss": str(gain),
})
# ---- Manual CGT disposals ----
manual_disposals = await list_manual_disposals(db, user_id, tax_year)
manual_disposal_rows = [_disposal_to_response(d) for d in manual_disposals]
total_manual_gain = sum(
(Decimal(str(d.proceeds)) - Decimal(str(d.cost_basis)) for d in manual_disposals),
Decimal("0"),
)
total_cgt_gain = total_inv_gain + total_manual_gain
# ---- Investment dividends ----
div_result = await db.execute(
select(InvestmentTransaction, InvestmentHolding, Asset)
.join(InvestmentHolding, InvestmentTransaction.holding_id == InvestmentHolding.id)
.join(Asset, InvestmentHolding.asset_id == Asset.id)
.where(
InvestmentTransaction.user_id == user_id,
InvestmentTransaction.type == "dividend",
InvestmentTransaction.date >= start_date,
InvestmentTransaction.date <= end_date,
)
.order_by(InvestmentTransaction.date)
)
dividend_rows = []
total_dividends = Decimal("0")
for inv_txn, holding, asset in div_result:
amount = Decimal(str(inv_txn.total_amount))
total_dividends += amount
dividend_rows.append({
"date": inv_txn.date.isoformat(),
"asset": asset.name,
"symbol": asset.symbol,
"amount": str(amount),
})
# ---- Calculations ----
income_tax_result = calculate_income_tax(gross_income, tax_code, rates)
ni_result = calculate_ni(gross_income, rates)
cgt_result = calculate_cgt(
total_cgt_gain,
income_tax_result["remaining_basic_rate_band"],
rates,
)
dividend_result = calculate_dividend_tax(
total_dividends,
income_tax_result["remaining_basic_rate_band"],
rates,
)
# ---- Totals ----
total_liability = (
income_tax_result["liability"]
+ ni_result["liability"]
+ cgt_result["liability"]
+ dividend_result["liability"]
)
total_withheld = income_tax_withheld + ni_withheld
net_owed = total_liability - total_withheld
return {
"tax_year": tax_year,
"tax_year_display": f"{tax_year - 1}/{str(tax_year)[2:]}",
"profile": profile_data,
"income": {
"gross_income": str(gross_income),
"income_tax_withheld": str(income_tax_withheld),
"ni_withheld": str(ni_withheld),
"payslips": payslip_rows,
},
"income_tax": {
**{k: str(v) if isinstance(v, Decimal) else v
for k, v in income_tax_result.items()
if k != "remaining_basic_rate_band"},
"withheld": str(income_tax_withheld),
"owed": str(income_tax_result["liability"] - income_tax_withheld),
},
"ni": {
**{k: str(v) if isinstance(v, Decimal) else v
for k, v in ni_result.items()},
"withheld": str(ni_withheld),
"owed": str(ni_result["liability"] - ni_withheld),
},
"cgt": {
**{k: str(v) if isinstance(v, Decimal) else v
for k, v in cgt_result.items()},
"investment_disposals": inv_disposal_rows,
"manual_disposals": manual_disposal_rows,
"total_gain": str(total_cgt_gain),
},
"dividends": {
**{k: str(v) if isinstance(v, Decimal) else v
for k, v in dividend_result.items()},
"dividend_transactions": dividend_rows,
},
"summary": {
"total_liability": str(total_liability),
"total_withheld": str(total_withheld),
"net_owed": str(net_owed),
"overpaid": net_owed < 0,
},
}

View file

@ -47,6 +47,7 @@ def _to_response(t: Transaction) -> dict:
"notes": _dec(t.notes_enc),
"tags": t.tags or [],
"is_recurring": t.is_recurring,
"recurring_rule": t.recurring_rule,
"attachment_refs": t.attachment_refs or [],
"created_at": t.created_at,
"updated_at": t.updated_at,
@ -221,6 +222,10 @@ async def update_transaction(
txn.notes_enc = _enc(data.notes)
if data.tags is not None:
txn.tags = data.tags
if data.is_recurring is not None:
txn.is_recurring = data.is_recurring
if data.recurring_rule is not None:
txn.recurring_rule = data.recurring_rule
txn.updated_at = now
await db.flush()
@ -307,5 +312,7 @@ async def import_csv(
await db.flush()
if imported > 0:
await recalculate_balance(db, account_id)
from app.services.recurring_service import detect_recurring
await detect_recurring(db, user_id)
return {"imported": imported, "skipped": skipped}

View file

View file

@ -0,0 +1,9 @@
"""
Conftest for backend tests.
Pure calculation tests (test_tax_calculations.py) import from tax_calculations.py
which has no external dependencies, so no stubs are needed.
"""
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

View file

@ -0,0 +1,311 @@
"""
Unit tests for the pure tax calculation functions.
Reference figures verified against HMRC's tax calculator and HMRC guidance.
All tests use the 2025/26 frozen rate structure (same as seed data).
"""
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from decimal import Decimal
from datetime import date
import pytest
from app.services.tax_calculations import (
tax_year_for_date,
parse_tax_code,
calculate_income_tax,
calculate_ni,
calculate_cgt,
calculate_dividend_tax,
)
# ---------------------------------------------------------------------------
# Shared rate fixture (mirrors the seeded data)
# ---------------------------------------------------------------------------
RATES = {
"income_tax": {
"bands": [
{"from": 0, "to": 12570, "rate": 0.00},
{"from": 12570, "to": 50270, "rate": 0.20},
{"from": 50270, "to": 125140, "rate": 0.40},
{"from": 125140, "to": None, "rate": 0.45},
]
},
"ni": {
"bands": [
{"from": 0, "to": 12570, "rate": 0.00},
{"from": 12570, "to": 50270, "rate": 0.08},
{"from": 50270, "to": None, "rate": 0.02},
]
},
"cgt": {"exempt": 3000, "basic_rate": 0.18, "higher_rate": 0.24},
"dividend": {
"allowance": 500,
"basic_rate": 0.0875,
"higher_rate": 0.3375,
"additional_rate": 0.3935,
},
}
# ---------------------------------------------------------------------------
# tax_year_for_date
# ---------------------------------------------------------------------------
class TestTaxYearForDate:
def test_before_april_6(self):
assert tax_year_for_date(date(2025, 4, 5)) == 2025
def test_on_april_6(self):
assert tax_year_for_date(date(2025, 4, 6)) == 2026
def test_mid_year(self):
assert tax_year_for_date(date(2025, 10, 1)) == 2026
def test_january(self):
assert tax_year_for_date(date(2026, 1, 15)) == 2026
def test_april_5_boundary(self):
# 5 April 2024 → tax_year 2024
assert tax_year_for_date(date(2024, 4, 5)) == 2024
def test_april_6_boundary(self):
# 6 April 2024 → tax_year 2025
assert tax_year_for_date(date(2024, 4, 6)) == 2025
# ---------------------------------------------------------------------------
# parse_tax_code
# ---------------------------------------------------------------------------
class TestParseTaxCode:
def test_standard_1257l(self):
r = parse_tax_code("1257L")
assert r["allowance"] == Decimal("12570")
assert r["rate_override"] is None
assert r["k_code"] is False
assert r["no_tax"] is False
def test_standard_1257m(self):
assert parse_tax_code("1257M")["allowance"] == Decimal("12570")
def test_standard_1257n(self):
assert parse_tax_code("1257N")["allowance"] == Decimal("12570")
def test_br(self):
r = parse_tax_code("BR")
assert r["allowance"] == Decimal("0")
assert r["rate_override"] == Decimal("0.20")
def test_d0(self):
r = parse_tax_code("D0")
assert r["rate_override"] == Decimal("0.40")
def test_d1(self):
r = parse_tax_code("D1")
assert r["rate_override"] == Decimal("0.45")
def test_nt(self):
r = parse_tax_code("NT")
assert r["no_tax"] is True
def test_0t(self):
r = parse_tax_code("0T")
assert r["allowance"] == Decimal("0")
assert r["rate_override"] is None
def test_k_code(self):
r = parse_tax_code("K100")
assert r["allowance"] == Decimal("-1000")
assert r["k_code"] is True
def test_k_code_large(self):
r = parse_tax_code("K497")
assert r["allowance"] == Decimal("-4970")
def test_w1_suffix_stripped(self):
r = parse_tax_code("1257L W1")
assert r["allowance"] == Decimal("12570")
def test_m1_suffix_stripped(self):
r = parse_tax_code("1257L/M1")
assert r["allowance"] == Decimal("12570")
def test_lowercase(self):
r = parse_tax_code("1257l")
assert r["allowance"] == Decimal("12570")
# ---------------------------------------------------------------------------
# calculate_income_tax
# ---------------------------------------------------------------------------
class TestCalculateIncomeTax:
def test_below_personal_allowance(self):
r = calculate_income_tax(Decimal("10000"), "1257L", RATES)
assert r["liability"] == Decimal("0")
assert r["taxable_income"] == Decimal("0")
assert r["personal_allowance"] == Decimal("12570")
def test_at_personal_allowance_boundary(self):
r = calculate_income_tax(Decimal("12570"), "1257L", RATES)
assert r["liability"] == Decimal("0")
def test_basic_rate_salary_30k(self):
# £30,000 gross — taxable = £17,430, basic rate 20%
r = calculate_income_tax(Decimal("30000"), "1257L", RATES)
assert r["taxable_income"] == Decimal("17430")
assert r["liability"] == Decimal("3486.00")
def test_at_higher_rate_threshold(self):
# £50,270 gross — exactly at the basic rate upper; taxable = £37,700
r = calculate_income_tax(Decimal("50270"), "1257L", RATES)
assert r["taxable_income"] == Decimal("37700")
# 20% on (50270 - 12570) = 37700
expected = (Decimal("37700") * Decimal("0.20")).quantize(Decimal("0.01"))
assert r["liability"] == expected
def test_higher_rate_taxpayer_60k(self):
# £60,000 gross, 1257L
# 20% on (50270 - 12570) = 37700 → £7,540
# 40% on (60000 - 50270) = 9730 → £3,892
r = calculate_income_tax(Decimal("60000"), "1257L", RATES)
assert r["taxable_income"] == Decimal("47430") # 60000 - 12570
expected_basic = (Decimal("37700") * Decimal("0.20")).quantize(Decimal("0.01"))
expected_higher = (Decimal("9730") * Decimal("0.40")).quantize(Decimal("0.01"))
assert r["liability"] == expected_basic + expected_higher
def test_personal_allowance_taper_110k(self):
# £110,000 — allowance tapered by £5,000 → £7,570
r = calculate_income_tax(Decimal("110000"), "1257L", RATES)
assert r["personal_allowance"] == Decimal("7570")
def test_personal_allowance_taper_125140(self):
# At £125,140 allowance tapers to zero
r = calculate_income_tax(Decimal("125140"), "1257L", RATES)
assert r["personal_allowance"] == Decimal("0")
def test_personal_allowance_above_125140(self):
# Above £125,140 allowance stays at zero
r = calculate_income_tax(Decimal("150000"), "1257L", RATES)
assert r["personal_allowance"] == Decimal("0")
def test_br_code(self):
r = calculate_income_tax(Decimal("30000"), "BR", RATES)
assert r["liability"] == (Decimal("30000") * Decimal("0.20")).quantize(Decimal("0.01"))
def test_d0_code(self):
r = calculate_income_tax(Decimal("30000"), "D0", RATES)
assert r["liability"] == (Decimal("30000") * Decimal("0.40")).quantize(Decimal("0.01"))
def test_nt_code(self):
r = calculate_income_tax(Decimal("100000"), "NT", RATES)
assert r["liability"] == Decimal("0")
def test_k_code_increases_taxable(self):
# K100 = -£1000 allowance → taxable = gross + £1000
r = calculate_income_tax(Decimal("30000"), "K100", RATES)
assert r["taxable_income"] == Decimal("31000")
def test_remaining_basic_rate_band_basic_taxpayer(self):
# £30k gross → remaining = 50270 - 30000 = 20270
r = calculate_income_tax(Decimal("30000"), "1257L", RATES)
assert r["remaining_basic_rate_band"] == Decimal("20270")
def test_remaining_basic_rate_band_higher_taxpayer(self):
# £60k gross > £50270 → band exhausted
r = calculate_income_tax(Decimal("60000"), "1257L", RATES)
assert r["remaining_basic_rate_band"] == Decimal("0")
# ---------------------------------------------------------------------------
# calculate_ni
# ---------------------------------------------------------------------------
class TestCalculateNI:
def test_below_threshold(self):
r = calculate_ni(Decimal("12570"), RATES)
assert r["liability"] == Decimal("0")
def test_basic_rate_salary_30k(self):
# NI on (30000 - 12570) = £17,430 at 8% = £1,394.40
r = calculate_ni(Decimal("30000"), RATES)
assert r["liability"] == Decimal("1394.40")
def test_above_upper_earnings_limit(self):
# NI on (50270 - 12570) = £37,700 at 8% = £3,016 + (60000 - 50270) = £9,730 at 2% = £194.60
r = calculate_ni(Decimal("60000"), RATES)
assert r["liability"] == Decimal("3016.00") + Decimal("194.60")
# ---------------------------------------------------------------------------
# calculate_cgt
# ---------------------------------------------------------------------------
class TestCalculateCGT:
def test_below_exempt(self):
r = calculate_cgt(Decimal("1000"), Decimal("20000"), RATES)
assert r["liability"] == Decimal("0")
assert r["taxable_gain"] == Decimal("0")
def test_exactly_exempt(self):
r = calculate_cgt(Decimal("3000"), Decimal("20000"), RATES)
assert r["liability"] == Decimal("0")
def test_basic_rate_taxpayer(self):
# Gain £10,000 — exempt £3,000 — taxable £7,000 all at basic rate 18%
r = calculate_cgt(Decimal("10000"), Decimal("20000"), RATES)
assert r["taxable_gain"] == Decimal("7000")
assert r["liability"] == (Decimal("7000") * Decimal("0.18")).quantize(Decimal("0.01"))
def test_higher_rate_taxpayer(self):
# remaining_basic_rate_band = 0 → all at higher rate 24%
r = calculate_cgt(Decimal("10000"), Decimal("0"), RATES)
assert r["liability"] == (Decimal("7000") * Decimal("0.24")).quantize(Decimal("0.01"))
def test_split_basic_higher(self):
# taxable gain £7,000; remaining_brb £4,000 → £4k at 18%, £3k at 24%
r = calculate_cgt(Decimal("10000"), Decimal("4000"), RATES)
expected = (
(Decimal("4000") * Decimal("0.18")) +
(Decimal("3000") * Decimal("0.24"))
).quantize(Decimal("0.01"))
assert r["liability"] == expected
def test_negative_gain_no_tax(self):
r = calculate_cgt(Decimal("-5000"), Decimal("20000"), RATES)
assert r["liability"] == Decimal("0")
assert r["taxable_gain"] == Decimal("0")
# ---------------------------------------------------------------------------
# calculate_dividend_tax
# ---------------------------------------------------------------------------
class TestCalculateDividendTax:
def test_within_allowance(self):
r = calculate_dividend_tax(Decimal("400"), Decimal("20000"), RATES)
assert r["liability"] == Decimal("0")
def test_exactly_allowance(self):
r = calculate_dividend_tax(Decimal("500"), Decimal("20000"), RATES)
assert r["liability"] == Decimal("0")
def test_basic_rate_band(self):
# £1,500 dividends — allowance £500 — taxable £1,000 at basic 8.75%
r = calculate_dividend_tax(Decimal("1500"), Decimal("20000"), RATES)
assert r["taxable_dividends"] == Decimal("1000")
assert r["liability"] == (Decimal("1000") * Decimal("0.0875")).quantize(Decimal("0.01"))
def test_higher_rate_band(self):
# Remaining basic = 0 → taxable £1,000 at higher 33.75%
r = calculate_dividend_tax(Decimal("1500"), Decimal("0"), RATES)
assert r["liability"] == (Decimal("1000") * Decimal("0.3375")).quantize(Decimal("0.01"))
def test_no_dividends(self):
r = calculate_dividend_tax(Decimal("0"), Decimal("20000"), RATES)
assert r["liability"] == Decimal("0")

View file

@ -0,0 +1,262 @@
"""
Schema round-trip tests for tax.py Pydantic models.
Verifies that each schema accepts valid data, rejects invalid data,
and that the nested TaxReportResponse correctly validates the shape
returned by build_tax_report().
"""
import sys, os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import uuid
from datetime import date
from decimal import Decimal
import pytest
from pydantic import ValidationError
from app.schemas.tax import (
ManualDisposalCreate,
ManualDisposalUpdate,
P60Entry,
PayslipCreate,
PayslipUpdate,
TaxProfileCreate,
TaxRateConfigUpdate,
TaxReportResponse,
)
# ---------------------------------------------------------------------------
# TaxRateConfigUpdate
# ---------------------------------------------------------------------------
class TestTaxRateConfigUpdate:
def test_partial_update_accepted(self):
u = TaxRateConfigUpdate(cgt={"exempt": 3000, "basic_rate": 0.18, "higher_rate": 0.24})
assert u.cgt["exempt"] == 3000
assert u.income_tax is None
def test_all_none_valid(self):
u = TaxRateConfigUpdate()
assert u.income_tax is None
assert u.ni is None
# ---------------------------------------------------------------------------
# TaxProfileCreate
# ---------------------------------------------------------------------------
class TestTaxProfileCreate:
def test_defaults(self):
p = TaxProfileCreate()
assert p.tax_code == "1257L"
assert p.is_cumulative is True
assert p.employer_name is None
def test_custom_values(self):
p = TaxProfileCreate(tax_code="BR", employer_name="Acme Ltd", is_cumulative=False)
assert p.tax_code == "BR"
assert p.employer_name == "Acme Ltd"
def test_tax_code_too_long(self):
with pytest.raises(ValidationError):
TaxProfileCreate(tax_code="X" * 21)
# ---------------------------------------------------------------------------
# PayslipCreate
# ---------------------------------------------------------------------------
class TestPayslipCreate:
def test_valid(self):
p = PayslipCreate(
period_month=4,
period_year=2024,
gross_pay=Decimal("3000.00"),
income_tax_withheld=Decimal("286.00"),
ni_withheld=Decimal("220.00"),
net_pay=Decimal("2494.00"),
)
assert p.period_month == 4
assert p.gross_pay == Decimal("3000.00")
def test_invalid_month(self):
with pytest.raises(ValidationError):
PayslipCreate(
period_month=13,
period_year=2024,
gross_pay=Decimal("3000"),
income_tax_withheld=Decimal("286"),
ni_withheld=Decimal("220"),
net_pay=Decimal("2494"),
)
def test_negative_gross_rejected(self):
with pytest.raises(ValidationError):
PayslipCreate(
period_month=4,
period_year=2024,
gross_pay=Decimal("-1"),
income_tax_withheld=Decimal("0"),
ni_withheld=Decimal("0"),
net_pay=Decimal("0"),
)
def test_p60_no_month(self):
p = PayslipCreate(
period_month=None,
period_year=2024,
gross_pay=Decimal("36000"),
income_tax_withheld=Decimal("4686"),
ni_withheld=Decimal("2394"),
net_pay=Decimal("28920"),
)
assert p.period_month is None
# ---------------------------------------------------------------------------
# P60Entry
# ---------------------------------------------------------------------------
class TestP60Entry:
def test_valid(self):
e = P60Entry(
gross_pay=Decimal("36000"),
income_tax_withheld=Decimal("4686"),
ni_withheld=Decimal("2394"),
net_pay=Decimal("28920"),
)
assert e.gross_pay == Decimal("36000")
def test_negative_rejected(self):
with pytest.raises(ValidationError):
P60Entry(
gross_pay=Decimal("-1"),
income_tax_withheld=Decimal("0"),
ni_withheld=Decimal("0"),
net_pay=Decimal("0"),
)
# ---------------------------------------------------------------------------
# ManualDisposalCreate
# ---------------------------------------------------------------------------
class TestManualDisposalCreate:
def test_valid(self):
d = ManualDisposalCreate(
disposal_date=date(2025, 1, 15),
asset_description="Rental property",
proceeds=Decimal("250000"),
cost_basis=Decimal("200000"),
)
assert d.proceeds == Decimal("250000")
assert d.notes is None
def test_empty_description_rejected(self):
with pytest.raises(ValidationError):
ManualDisposalCreate(
disposal_date=date(2025, 1, 15),
asset_description="",
proceeds=Decimal("1000"),
cost_basis=Decimal("500"),
)
def test_negative_proceeds_rejected(self):
with pytest.raises(ValidationError):
ManualDisposalCreate(
disposal_date=date(2025, 1, 15),
asset_description="Something",
proceeds=Decimal("-1"),
cost_basis=Decimal("500"),
)
# ---------------------------------------------------------------------------
# TaxReportResponse — validates the full nested report shape
# ---------------------------------------------------------------------------
SAMPLE_REPORT = {
"tax_year": 2025,
"tax_year_display": "2024/25",
"profile": {
"id": str(uuid.uuid4()),
"tax_year": 2025,
"tax_code": "1257L",
"employer_name": "Acme Ltd",
"is_cumulative": True,
"created_at": "2025-01-01T00:00:00+00:00",
"updated_at": "2025-01-01T00:00:00+00:00",
},
"income": {
"gross_income": "45000.00",
"income_tax_withheld": "6486.00",
"ni_withheld": "2634.00",
"payslips": [],
},
"income_tax": {
"personal_allowance": "12570.00",
"taxable_income": "32430.00",
"liability": "6486.00",
"band_breakdown": [{"rate": 0.20, "taxable": 32430.0, "tax": 6486.0}],
"withheld": "6486.00",
"owed": "0.00",
},
"ni": {
"liability": "2634.00",
"band_breakdown": [{"rate": 0.08, "taxable": 32430.0, "tax": 2594.4}],
"withheld": "2634.00",
"owed": "0.00",
},
"cgt": {
"gross_gain": "0.00",
"exempt": "0.00",
"taxable_gain": "0.00",
"liability": "0.00",
"band_breakdown": [],
"investment_disposals": [],
"manual_disposals": [],
"total_gain": "0.00",
},
"dividends": {
"gross_dividends": "0.00",
"allowance": "0.00",
"taxable_dividends": "0.00",
"liability": "0.00",
"band_breakdown": [],
"dividend_transactions": [],
},
"summary": {
"total_liability": "9120.00",
"total_withheld": "9120.00",
"net_owed": "0.00",
"overpaid": False,
},
}
class TestTaxReportResponse:
def test_valid_report_parses(self):
r = TaxReportResponse(**SAMPLE_REPORT)
assert r.tax_year == 2025
assert r.tax_year_display == "2024/25"
assert r.summary.net_owed == "0.00"
assert r.summary.overpaid is False
def test_no_profile(self):
report = {**SAMPLE_REPORT, "profile": None}
r = TaxReportResponse(**report)
assert r.profile is None
def test_missing_summary_field_rejected(self):
bad_summary = {**SAMPLE_REPORT["summary"]}
del bad_summary["overpaid"]
with pytest.raises(ValidationError):
TaxReportResponse(**{**SAMPLE_REPORT, "summary": bad_summary})
def test_missing_income_field_rejected(self):
bad_income = {**SAMPLE_REPORT["income"]}
del bad_income["gross_income"]
with pytest.raises(ValidationError):
TaxReportResponse(**{**SAMPLE_REPORT, "income": bad_income})