diff --git a/backend/app/ml/spending_forecast.py b/backend/app/ml/spending_forecast.py index 68d186b..d42a0b1 100644 --- a/backend/app/ml/spending_forecast.py +++ b/backend/app/ml/spending_forecast.py @@ -10,19 +10,55 @@ import pandas as pd warnings.filterwarnings("ignore") MIN_POINTS = 3 -FORECAST_MONTHS = 3 +FORECAST_MONTHS = 6 def _next_month_starts(from_date: date, n: int) -> list[str]: months = [] - d = (from_date.replace(day=1) + relativedelta(months=1)) + d = from_date.replace(day=1) + relativedelta(months=1) for _ in range(n): months.append(d.strftime("%Y-%m-%d")) d += relativedelta(months=1) return months -def _fit_holt(values: list[float], n: int) -> tuple[list[float], list[float], list[float]]: +def _fit_sarima(values: list[float], n: int) -> tuple[list[float], list[float], list[float], list[float], list[float], str]: + """ + Primary algorithm. Uses SARIMAX with seasonal component when enough data exists, + plain ARIMA otherwise. Returns (forecast, lower_80, upper_80, lower_95, upper_95, algorithm). + """ + from statsmodels.tsa.statespace.sarimax import SARIMAX + + series = np.array(values, dtype=float) + algo = "sarima" + + try: + if len(series) >= 12: + # Seasonal ARIMA with annual period + model = SARIMAX(series, order=(1, 1, 1), seasonal_order=(1, 0, 1, 12), + enforce_stationarity=False, enforce_invertibility=False) + else: + model = SARIMAX(series, order=(1, 1, 1), + enforce_stationarity=False, enforce_invertibility=False) + + fit = model.fit(disp=False, maxiter=200) + forecast_obj = fit.get_forecast(steps=n) + mean = forecast_obj.predicted_mean + ci_80 = forecast_obj.conf_int(alpha=0.20) # 80% interval + ci_95 = forecast_obj.conf_int(alpha=0.05) # 95% interval + + lower_80 = np.maximum(0, ci_80.iloc[:, 0].values).tolist() + upper_80 = ci_80.iloc[:, 1].values.tolist() + lower_95 = np.maximum(0, ci_95.iloc[:, 0].values).tolist() + upper_95 = ci_95.iloc[:, 1].values.tolist() + return mean.tolist(), lower_80, upper_80, lower_95, upper_95, algo + + except Exception: + return _fit_holt(values, n) + + +def _fit_holt(values: list[float], n: int) -> tuple[list[float], list[float], list[float], list[float], list[float], str]: + """Holt-Winters fallback.""" from statsmodels.tsa.holtwinters import ExponentialSmoothing try: @@ -36,13 +72,22 @@ def _fit_holt(values: list[float], n: int) -> tuple[list[float], list[float], li fit = model.fit(optimized=True, disp=False) forecast = fit.forecast(n) sigma = float(np.std(fit.resid)) if len(fit.resid) > 1 else float(np.mean(values) * 0.15) - lower = np.maximum(0, forecast - 1.28 * sigma) - upper = forecast + 1.28 * sigma - return forecast.tolist(), lower.tolist(), upper.tolist() + + lower_80 = np.maximum(0, forecast - 1.28 * sigma).tolist() + upper_80 = (forecast + 1.28 * sigma).tolist() + lower_95 = np.maximum(0, forecast - 1.96 * sigma).tolist() + upper_95 = (forecast + 1.96 * sigma).tolist() + return forecast.tolist(), lower_80, upper_80, lower_95, upper_95, "holt_winters" + except Exception: avg = float(np.mean(values)) sigma = float(np.std(values)) if len(values) > 1 else avg * 0.15 - return [avg] * n, [max(0, avg - 1.28 * sigma)] * n, [(avg + 1.28 * sigma)] * n + fcast = [avg] * n + lower_80 = [max(0.0, avg - 1.28 * sigma)] * n + upper_80 = [(avg + 1.28 * sigma)] * n + lower_95 = [max(0.0, avg - 1.96 * sigma)] * n + upper_95 = [(avg + 1.96 * sigma)] * n + return fcast, lower_80, upper_80, lower_95, upper_95, "average" def forecast_spending(df: pd.DataFrame) -> list[dict]: @@ -61,31 +106,47 @@ def forecast_spending(df: pd.DataFrame) -> list[dict]: group = group.sort_values("ds") values = group["y"].tolist() actuals = [ - {"date": row["ds"].strftime("%Y-%m-%d"), "amount": row["y"]} + {"date": row["ds"].strftime("%Y-%m-%d"), "amount": round(float(row["y"]), 2)} for _, row in group.iterrows() ] if len(values) < MIN_POINTS: avg = float(np.mean(values)) + sigma = avg * 0.15 forecast_pts = [ - {"date": d, "amount": round(avg, 2), "lower": round(avg * 0.7, 2), "upper": round(avg * 1.3, 2)} + { + "date": d, + "amount": round(avg, 2), + "lower": round(max(0.0, avg - 1.28 * sigma), 2), + "upper": round(avg + 1.28 * sigma, 2), + "lower_95": round(max(0.0, avg - 1.96 * sigma), 2), + "upper_95": round(avg + 1.96 * sigma, 2), + } for d in future_dates ] + algo = "average" else: - fcast, lower, upper = _fit_holt(values, FORECAST_MONTHS) + fcast, lower_80, upper_80, lower_95, upper_95, algo = _fit_sarima(values, FORECAST_MONTHS) forecast_pts = [ - {"date": d, "amount": round(max(0, f), 2), "lower": round(l, 2), "upper": round(u, 2)} - for d, f, l, u in zip(future_dates, fcast, lower, upper) + { + "date": d, + "amount": round(max(0.0, f), 2), + "lower": round(l80, 2), + "upper": round(u80, 2), + "lower_95": round(l95, 2), + "upper_95": round(u95, 2), + } + for d, f, l80, u80, l95, u95 in zip(future_dates, fcast, lower_80, upper_80, lower_95, upper_95) ] results.append({ - "category_id": cat_id, + "category_id": str(cat_id), "category_name": cat_name, "monthly_avg": round(float(np.mean(values)), 2), - "actuals": actuals[-6:], # last 6 months for display + "algorithm": algo, + "actuals": actuals[-12:], # last 12 months for display "forecast": forecast_pts, }) - # Sort by monthly_avg descending (highest spend first) results.sort(key=lambda x: x["monthly_avg"], reverse=True) return results diff --git a/backend/pyproject.toml b/backend/pyproject.toml index aae74a9..0a23186 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -17,7 +17,6 @@ dependencies = [ "qrcode[pil]>=8.0", "cryptography>=44.0", "yfinance>=0.2", - "prophet>=1.1", "statsmodels>=0.14", "numpy>=2.0", "scipy>=1.14", @@ -52,6 +51,7 @@ build-backend = "hatchling.build" [tool.pytest.ini_options] asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "session" testpaths = ["tests"] [tool.hatch.build.targets.wheel] diff --git a/frontend/src/api/predictions.ts b/frontend/src/api/predictions.ts index 425c1e3..e1966de 100644 --- a/frontend/src/api/predictions.ts +++ b/frontend/src/api/predictions.ts @@ -4,8 +4,9 @@ export interface CategoryForecast { category_id: string; category_name: string; monthly_avg: number; + algorithm: "sarima" | "holt_winters" | "average"; actuals: { date: string; amount: number }[]; - forecast: { date: string; amount: number; lower: number; upper: number }[]; + forecast: { date: string; amount: number; lower: number; upper: number; lower_95: number; upper_95: number }[]; } export interface SpendingForecastResponse { diff --git a/frontend/src/pages/predictions/PredictionsPage.tsx b/frontend/src/pages/predictions/PredictionsPage.tsx index d7bc471..ceb6c3b 100644 --- a/frontend/src/pages/predictions/PredictionsPage.tsx +++ b/frontend/src/pages/predictions/PredictionsPage.tsx @@ -205,6 +205,12 @@ function BudgetForecastCard({ forecast: f }: { forecast: import("@/api/predictio // ─── Spending Forecast ─────────────────────────────────────────────────────── +const ALGO_LABELS: Record = { + sarima: "SARIMA", + holt_winters: "Holt-Winters", + average: "Avg", +}; + function SpendingTab() { const { data, isLoading } = useQuery({ queryKey: ["pred-spending"], queryFn: getSpendingForecast }); const [selected, setSelected] = useState(0); @@ -218,8 +224,10 @@ function SpendingTab() { ...cat.forecast.map(p => ({ date: p.date.slice(0, 7), forecast: p.amount, - lower: p.lower, - upper: p.upper, + lower_80: p.lower, + upper_80: p.upper, + lower_95: p.lower_95, + upper_95: p.upper_95, })), ]; @@ -246,29 +254,38 @@ function SpendingTab() {
-

{cat.category_name} — Spending Forecast

-

Shaded = 80% confidence interval

+
+

{cat.category_name} — 6-Month Forecast

+ + {ALGO_LABELS[cat.algorithm] ?? cat.algorithm} + +
+

Dark = 80% · Light = 95% confidence

- + `£${v}`} width={55} /> - formatCurrency(v, "GBP")} /> + [formatCurrency(v, "GBP"), name]} /> - + + + - {/* Confidence band as area overlay */} {cat.forecast.length > 0 && ( -
- Forecast next 3 months: {cat.forecast.map(f => - `${f.date.slice(0, 7)}: ${formatCurrency(f.amount, "GBP")} (${formatCurrency(f.lower, "GBP")}–${formatCurrency(f.upper, "GBP")})` - ).join(" · ")} +
+ {cat.forecast.slice(0, 3).map(f => ( +
+

{f.date.slice(0, 7)}

+

{formatCurrency(f.amount, "GBP")}

+

{formatCurrency(f.lower_95, "GBP")}–{formatCurrency(f.upper_95, "GBP")}

+
+ ))}
)}
-
); }