from __future__ import annotations import re import uuid from collections import defaultdict from datetime import date, datetime, timezone from decimal import Decimal from statistics import mean, stdev from dateutil.relativedelta import relativedelta from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.core.security import decrypt_field from app.db.models.transaction import Transaction MIN_OCCURRENCES = 2 # (label, min_days, max_days, tolerance_days) _FREQUENCIES = [ ("weekly", 6, 8, 2), ("fortnightly", 13, 15, 3), ("monthly", 26, 35, 5), ("quarterly", 85, 95, 10), ("yearly", 355, 375, 15), ] _STRIP_PREFIXES = re.compile( r"^(direct debit|standing order|faster payment|bacs|dd|so)\s+", re.IGNORECASE, ) _STRIP_REFS = re.compile(r"\s+\d{5,}$") _STRIP_DATE_PATTERNS = re.compile( r"\b\d{1,2}[/-]\d{1,2}([/-]\d{2,4})?\b" r"|\b\d{1,2}(jan|feb|mar|apr|may|jun|jul|aug|sep|oct|nov|dec)\b", re.IGNORECASE, ) _COLLAPSE_SPACES = re.compile(r"\s{2,}") def normalise_description(raw: str) -> str: s = raw.lower().strip() s = _STRIP_PREFIXES.sub("", s) s = _STRIP_DATE_PATTERNS.sub("", s) s = _STRIP_REFS.sub("", s) s = _COLLAPSE_SPACES.sub(" ", s).strip() return s def classify_frequency(avg_days: float) -> tuple[str, int] | None: """Return (label, expected_days) or None if avg_days matches no known frequency.""" for label, lo, hi, _ in _FREQUENCIES: if lo <= avg_days <= hi: return label, round(avg_days) return None def _within_tolerance(intervals: list[int], avg: float, frequency: str) -> bool: tolerance = next(t for label, _, _, t in _FREQUENCIES if label == frequency) return all(abs(iv - avg) <= tolerance for iv in intervals) def next_expected_date(last: date, frequency: str) -> date: delta_map = { "weekly": relativedelta(weeks=1), "fortnightly": relativedelta(weeks=2), "monthly": relativedelta(months=1), "quarterly": relativedelta(months=3), "yearly": relativedelta(years=1), } delta = delta_map[frequency] result = last + delta today = date.today() while result < today: result += delta return result def _confidence(intervals: list[int], expected: int) -> float: if len(intervals) < 2: return 1.0 try: sd = stdev(intervals) except Exception: sd = 0.0 conf = 1.0 - (sd / expected) if expected > 0 else 0.0 return round(max(0.0, min(1.0, conf)), 4) async def detect_recurring(db: AsyncSession, user_id: uuid.UUID) -> dict: """ Scan all transactions for a user, detect recurring patterns, and tag them. Skips transactions where recurring_rule.manually_set == true. Returns {"newly_tagged": int, "total_recurring": int}. """ result = await db.execute( select(Transaction).where( Transaction.user_id == user_id, Transaction.deleted_at.is_(None), ) ) transactions = result.scalars().all() txn_map: dict[uuid.UUID, Transaction] = {t.id: t for t in transactions} # Group by (normalised_description, exact_amount) keyed: dict[tuple[str, Decimal], tuple[list[date], list[uuid.UUID]]] = defaultdict( lambda: ([], []) ) for txn in transactions: try: desc = decrypt_field(txn.description_enc) or "" except Exception: desc = "" norm = normalise_description(desc) amount = txn.amount.quantize(Decimal("0.01")) dates_list, ids_list = keyed[(norm, amount)] dates_list.append(txn.date) ids_list.append(txn.id) now = datetime.now(timezone.utc) newly_tagged = 0 total_recurring = 0 matched_ids: set[uuid.UUID] = set() for (norm_desc, amount), (dates_list, ids_list) in keyed.items(): if len(dates_list) < MIN_OCCURRENCES: continue paired = sorted(zip(dates_list, ids_list), key=lambda x: x[0]) sorted_dates = [p[0] for p in paired] sorted_ids = [p[1] for p in paired] intervals = [ (sorted_dates[i + 1] - sorted_dates[i]).days for i in range(len(sorted_dates) - 1) ] avg = mean(intervals) freq_result = classify_frequency(avg) if freq_result is None: continue frequency, expected_days = freq_result if not _within_tolerance(intervals, avg, frequency): continue conf = _confidence(intervals, expected_days) last_date = sorted_dates[-1] next_date = next_expected_date(last_date, frequency) for txn_id in sorted_ids: txn = txn_map.get(txn_id) if txn is None: continue matched_ids.add(txn_id) rule = txn.recurring_rule or {} if rule.get("manually_set"): if txn.is_recurring: total_recurring += 1 continue was_recurring = txn.is_recurring txn.is_recurring = True txn.recurring_rule = { "frequency": frequency, "typical_amount": float(amount), "next_expected": next_date.isoformat(), "last_paid": last_date.isoformat(), "confidence": conf, "detected_at": now.isoformat(), "manually_set": False, } txn.updated_at = now if not was_recurring: newly_tagged += 1 total_recurring += 1 # Un-tag previously auto-detected transactions whose pattern no longer matches for txn in transactions: if not txn.is_recurring: continue rule = txn.recurring_rule or {} if rule.get("manually_set"): continue if txn.id not in matched_ids: txn.is_recurring = False txn.recurring_rule = None txn.updated_at = now await db.flush() return {"newly_tagged": newly_tagged, "total_recurring": total_recurring}