Initial commit: MyMidas personal finance tracker

Full-stack self-hosted finance app with FastAPI backend and React frontend.

Features:
- Accounts, transactions, budgets, investments with GBP base currency
- CSV import with auto-detection for 10 UK bank formats
- ML predictions: spending forecast, net worth projection, Monte Carlo
- 7 selectable themes (Obsidian, Arctic, Midnight, Vault, Terminal, Synthwave, Ledger)
- Receipt/document attachments on transactions (JPEG, PNG, WebP, PDF)
- AES-256-GCM field encryption, RS256 JWT, TOTP 2FA, RLS, audit log
- Encrypted nightly backups + key rotation script
- Mobile-responsive layout with bottom nav

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
megaproxy 2026-04-21 11:56:10 +00:00
commit 61a7884ee5
127 changed files with 13323 additions and 0 deletions

19
backend/Dockerfile Normal file
View file

@ -0,0 +1,19 @@
# ---- base: OS packages + installer shared by all later stages ----
FROM python:3.12-slim AS base
# libmagic1: native library for file-type detection.
# NOTE(review): presumably backs MIME sniffing of uploaded attachments —
# confirm the app actually depends on python-magic.
RUN apt-get update && apt-get install -y --no-install-recommends \
    libmagic1 \
    && rm -rf /var/lib/apt/lists/*
# uv is used purely as a faster pip replacement below.
RUN pip install --no-cache-dir uv
WORKDIR /app
# Copy only the manifest first so the dependency layer caches across code changes.
COPY pyproject.toml ./

# ---- deps: Python dependencies ----
FROM base AS deps
# NOTE(review): editable install (-e .) with only pyproject.toml present —
# confirm this resolves without the package source copied in; a plain
# `uv pip install .` after COPY app/ may be what was intended.
RUN uv pip install --system --no-cache -e .

# ---- production: app code, unprivileged user, entrypoint ----
FROM deps AS production
COPY app/ ./app/
COPY alembic/ ./alembic/
COPY alembic.ini ./
# System user with no shell; uploads dir must be writable by the app at runtime.
RUN useradd -r -s /bin/false -u 1001 appuser && chown -R appuser /app && mkdir -p /app/uploads && chown appuser /app/uploads
USER appuser
EXPOSE 8000
# Apply DB migrations, then serve; --proxy-headers assumes a trusted reverse proxy in front.
CMD ["sh", "-c", "python -m alembic upgrade head && uvicorn app.main:app --host 0.0.0.0 --port 8000 --workers 2 --proxy-headers"]

41
backend/alembic.ini Normal file
View file

@ -0,0 +1,41 @@
[alembic]
script_location = alembic
prepend_sys_path = .
version_path_separator = os
# Placeholder connection string only — env.py overrides this with the
# DATABASE_URL environment variable at runtime. Do not put real credentials here.
sqlalchemy.url = postgresql://finance_app:password@postgres:5432/financedb

[post_write_hooks]

# ---- Logging (standard Python logging.fileConfig sections) ----
[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARN
handlers = console
qualname =

# Keep SQLAlchemy quiet except for warnings; raise to INFO to see emitted SQL.
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

55
backend/alembic/env.py Normal file
View file

@ -0,0 +1,55 @@
import asyncio
import os
from logging.config import fileConfig
from alembic import context
from sqlalchemy.ext.asyncio import create_async_engine
from app.db.base import Base
import app.db.models  # noqa: F401 — ensure all models are registered

# Alembic Config object: access to values inside alembic.ini.
config = context.config
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# Metadata that autogenerate compares the live schema against.
target_metadata = Base.metadata

# Override URL from env if available; also upgrade a plain postgresql://
# scheme to the asyncpg driver required by the async engine below.
db_url = os.environ.get("DATABASE_URL", config.get_main_option("sqlalchemy.url"))
if db_url and db_url.startswith("postgresql://"):
    db_url = db_url.replace("postgresql://", "postgresql+asyncpg://", 1)


def run_migrations_offline() -> None:
    """Emit migration SQL to stdout without a live DB connection ('offline' mode)."""
    context.configure(
        url=db_url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )
    with context.begin_transaction():
        context.run_migrations()


def do_run_migrations(connection) -> None:
    """Run migrations on a sync-style connection (invoked via run_sync)."""
    context.configure(connection=connection, target_metadata=target_metadata)
    with context.begin_transaction():
        context.run_migrations()


async def run_async_migrations() -> None:
    """Create an async engine, run the migrations, then dispose the pool."""
    engine = create_async_engine(db_url)
    async with engine.begin() as conn:
        await conn.run_sync(do_run_migrations)
    await engine.dispose()


def run_migrations_online() -> None:
    """'Online' mode entry point: drive the async runner to completion."""
    asyncio.run(run_async_migrations())


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()

View file

@ -0,0 +1,308 @@
"""initial schema
Revision ID: 0001
Revises:
Create Date: 2026-04-20
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# Alembic revision identifiers.
revision = "0001"
down_revision = None  # first migration in the chain
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create the full initial schema.

    Layout notes:
    - LargeBinary columns (account/transaction names, descriptions, notes)
      hold AES-GCM ciphertext — field-level encryption happens in the app.
    - All user-owned tables get row-level security with FORCE, so even the
      table-owning application role is subject to the policies.
    """
    # users
    op.create_table(
        "users",
        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
        sa.Column("email", sa.Text, nullable=False, unique=True),
        sa.Column("password_hash", sa.Text, nullable=False),
        sa.Column("totp_secret", sa.Text, nullable=True),
        sa.Column("totp_enabled", sa.Boolean, nullable=False, server_default="false"),
        sa.Column("totp_backup_codes", sa.Text, nullable=True),
        sa.Column("display_name", sa.Text, nullable=False),
        sa.Column("base_currency", sa.String(10), nullable=False, server_default="GBP"),
        sa.Column("theme", sa.String(20), nullable=False, server_default="dark"),
        sa.Column("locale", sa.String(20), nullable=False, server_default="en-GB"),
        sa.Column("failed_login_attempts", sa.Integer, nullable=False, server_default="0"),
        sa.Column("locked_until", sa.DateTime(timezone=True), nullable=True),
        sa.Column("last_login_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("last_login_ip", postgresql.INET, nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),  # soft delete
    )
    op.create_index("ix_users_email", "users", ["email"])
    # sessions
    op.create_table(
        "sessions",
        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
        sa.Column("user_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
        sa.Column("token_hash", sa.Text, nullable=False, unique=True),
        sa.Column("ip_address", postgresql.INET, nullable=True),
        sa.Column("user_agent", sa.Text, nullable=True),
        sa.Column("last_active_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("revoked_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
    )
    op.create_index("ix_sessions_user_id", "sessions", ["user_id"])
    op.create_index("ix_sessions_token_hash", "sessions", ["token_hash"])
    # accounts — name/institution/notes are encrypted blobs, not plaintext
    op.create_table(
        "accounts",
        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
        sa.Column("user_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
        sa.Column("name", sa.LargeBinary, nullable=False),
        sa.Column("institution", sa.LargeBinary, nullable=True),
        sa.Column("type", sa.String(30), nullable=False),
        sa.Column("currency", sa.String(10), nullable=False),
        sa.Column("current_balance", sa.Numeric(20, 8), nullable=False, server_default="0"),
        sa.Column("credit_limit", sa.Numeric(20, 8), nullable=True),
        sa.Column("interest_rate", sa.Numeric(8, 4), nullable=True),
        sa.Column("is_active", sa.Boolean, nullable=False, server_default="true"),
        sa.Column("include_in_net_worth", sa.Boolean, nullable=False, server_default="true"),
        sa.Column("color", sa.String(7), nullable=False, server_default="#6366f1"),
        sa.Column("icon", sa.Text, nullable=True),
        sa.Column("notes", sa.LargeBinary, nullable=True),
        sa.Column("meta", postgresql.JSONB, nullable=False, server_default="{}"),
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
    )
    op.create_index("ix_accounts_user_id", "accounts", ["user_id"])
    # categories — user_id NULL means a shared/system category
    op.create_table(
        "categories",
        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
        sa.Column("user_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=True),
        sa.Column("name", sa.Text, nullable=False),
        sa.Column("parent_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("categories.id"), nullable=True),
        sa.Column("type", sa.String(20), nullable=False),
        sa.Column("icon", sa.Text, nullable=True),
        sa.Column("color", sa.String(7), nullable=True),
        sa.Column("is_system", sa.Boolean, nullable=False, server_default="false"),
        sa.Column("sort_order", sa.Integer, nullable=False, server_default="0"),
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
    )
    # transactions
    op.create_table(
        "transactions",
        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
        sa.Column("user_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
        sa.Column("account_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("accounts.id"), nullable=False),
        sa.Column("transfer_account_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("accounts.id"), nullable=True),
        sa.Column("category_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("categories.id"), nullable=True),
        sa.Column("type", sa.String(20), nullable=False),
        sa.Column("status", sa.String(20), nullable=False, server_default="cleared"),
        sa.Column("amount", sa.Numeric(20, 8), nullable=False),
        sa.Column("amount_base", sa.Numeric(20, 8), nullable=True),
        sa.Column("currency", sa.String(10), nullable=False),
        sa.Column("base_currency", sa.String(10), nullable=False),
        sa.Column("exchange_rate", sa.Numeric(20, 10), nullable=True),
        sa.Column("date", sa.Date, nullable=False),
        sa.Column("description", sa.LargeBinary, nullable=False),
        sa.Column("merchant", sa.LargeBinary, nullable=True),
        sa.Column("notes", sa.LargeBinary, nullable=True),
        sa.Column("tags", postgresql.ARRAY(sa.Text), nullable=False, server_default="{}"),
        sa.Column("is_recurring", sa.Boolean, nullable=False, server_default="false"),
        sa.Column("recurring_rule", postgresql.JSONB, nullable=True),
        sa.Column("attachment_refs", postgresql.JSONB, nullable=False, server_default="[]"),
        # import_hash deduplicates CSV re-imports of the same row
        sa.Column("import_hash", sa.Text, nullable=True),
        sa.Column("meta", postgresql.JSONB, nullable=False, server_default="{}"),
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
    )
    op.create_index("ix_transactions_user_id_date", "transactions", ["user_id", "date"])
    op.create_index("ix_transactions_account_id", "transactions", ["account_id"])
    op.create_index("ix_transactions_import_hash", "transactions", ["import_hash"])
    # budgets
    op.create_table(
        "budgets",
        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
        sa.Column("user_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
        sa.Column("category_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("categories.id"), nullable=False),
        sa.Column("name", sa.Text, nullable=False),
        sa.Column("amount", sa.Numeric(20, 8), nullable=False),
        sa.Column("currency", sa.String(10), nullable=False),
        sa.Column("period", sa.String(20), nullable=False),
        sa.Column("start_date", sa.Date, nullable=False),
        sa.Column("end_date", sa.Date, nullable=True),
        sa.Column("rollover", sa.Boolean, nullable=False, server_default="false"),
        # percentage at which the UI warns, not a hard limit
        sa.Column("alert_threshold", sa.Numeric(5, 2), nullable=False, server_default="80"),
        sa.Column("is_active", sa.Boolean, nullable=False, server_default="true"),
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
    )
    # assets — global reference data, not user-owned (no RLS)
    op.create_table(
        "assets",
        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
        sa.Column("symbol", sa.Text, nullable=False),
        sa.Column("name", sa.Text, nullable=False),
        sa.Column("type", sa.String(30), nullable=False),
        sa.Column("currency", sa.String(10), nullable=False),
        sa.Column("exchange", sa.Text, nullable=True),
        sa.Column("isin", sa.String(12), nullable=True),
        sa.Column("data_source", sa.String(30), nullable=False, server_default="yahoo_finance"),
        sa.Column("data_source_id", sa.Text, nullable=True),
        sa.Column("last_price", sa.Numeric(20, 8), nullable=True),
        sa.Column("last_price_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("price_change_24h", sa.Numeric(10, 4), nullable=True),
        sa.Column("is_active", sa.Boolean, nullable=False, server_default="true"),
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
    )
    op.create_unique_constraint("uq_assets_symbol_exchange", "assets", ["symbol", "exchange"])
    # asset_prices — daily OHLCV history per asset
    op.create_table(
        "asset_prices",
        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
        sa.Column("asset_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False),
        sa.Column("date", sa.Date, nullable=False),
        sa.Column("open", sa.Numeric(20, 8), nullable=True),
        sa.Column("high", sa.Numeric(20, 8), nullable=True),
        sa.Column("low", sa.Numeric(20, 8), nullable=True),
        sa.Column("close", sa.Numeric(20, 8), nullable=False),
        sa.Column("volume", sa.Numeric(30, 8), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
    )
    op.create_unique_constraint("uq_asset_prices_asset_date", "asset_prices", ["asset_id", "date"])
    # investment_holdings
    op.create_table(
        "investment_holdings",
        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
        sa.Column("user_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
        sa.Column("account_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("accounts.id"), nullable=False),
        sa.Column("asset_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("assets.id"), nullable=False),
        sa.Column("quantity", sa.Numeric(30, 10), nullable=False, server_default="0"),
        sa.Column("avg_cost_basis", sa.Numeric(20, 8), nullable=False, server_default="0"),
        sa.Column("currency", sa.String(10), nullable=False),
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
    )
    op.create_unique_constraint("uq_holdings_account_asset", "investment_holdings", ["account_id", "asset_id"])
    # investment_transactions
    op.create_table(
        "investment_transactions",
        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
        sa.Column("user_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
        sa.Column("holding_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("investment_holdings.id"), nullable=False),
        sa.Column("transaction_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("transactions.id"), nullable=True),
        sa.Column("type", sa.String(20), nullable=False),
        sa.Column("quantity", sa.Numeric(30, 10), nullable=False),
        sa.Column("price", sa.Numeric(20, 8), nullable=False),
        sa.Column("fees", sa.Numeric(20, 8), nullable=False, server_default="0"),
        sa.Column("total_amount", sa.Numeric(20, 8), nullable=False),
        sa.Column("currency", sa.String(10), nullable=False),
        sa.Column("date", sa.Date, nullable=False),
        sa.Column("notes", sa.LargeBinary, nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
    )
    # currencies — reference table, seeded below
    op.create_table(
        "currencies",
        sa.Column("code", sa.String(10), primary_key=True),
        sa.Column("name", sa.Text, nullable=False),
        sa.Column("symbol", sa.String(5), nullable=False),
        sa.Column("is_crypto", sa.Boolean, nullable=False, server_default="false"),
        sa.Column("decimal_places", sa.Integer, nullable=False, server_default="2"),
        sa.Column("is_active", sa.Boolean, nullable=False, server_default="true"),
    )
    # exchange_rates
    op.create_table(
        "exchange_rates",
        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
        sa.Column("base_currency", sa.String(10), nullable=False),
        sa.Column("quote_currency", sa.String(10), nullable=False),
        sa.Column("rate", sa.Numeric(20, 10), nullable=False),
        sa.Column("source", sa.String(50), nullable=False),
        sa.Column("fetched_at", sa.DateTime(timezone=True), nullable=False),
    )
    op.create_index("ix_exchange_rates_pair", "exchange_rates", ["base_currency", "quote_currency", "fetched_at"])
    # net_worth_snapshots
    op.create_table(
        "net_worth_snapshots",
        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
        sa.Column("user_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
        sa.Column("date", sa.Date, nullable=False),
        sa.Column("total_assets", sa.Numeric(20, 8), nullable=False),
        sa.Column("total_liabilities", sa.Numeric(20, 8), nullable=False),
        sa.Column("net_worth", sa.Numeric(20, 8), nullable=False),
        sa.Column("base_currency", sa.String(10), nullable=False),
        sa.Column("breakdown", postgresql.JSONB, nullable=False, server_default="{}"),
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
    )
    op.create_unique_constraint("uq_nw_snapshots_user_date", "net_worth_snapshots", ["user_id", "date"])
    # audit_logs — user_id survives user deletion (SET NULL) to keep the trail
    op.create_table(
        "audit_logs",
        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
        sa.Column("user_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL"), nullable=True),
        sa.Column("action", sa.String(50), nullable=False),
        sa.Column("resource_type", sa.Text, nullable=True),
        sa.Column("resource_id", postgresql.UUID(as_uuid=True), nullable=True),
        sa.Column("ip_address", postgresql.INET, nullable=True),
        sa.Column("user_agent", sa.Text, nullable=True),
        # NOTE(review): DB column literally named "metadata" — the ORM model
        # must map it under a different attribute name (reserved in SQLAlchemy
        # declarative); confirm.
        sa.Column("metadata", postgresql.JSONB, nullable=False, server_default="{}"),
        sa.Column("success", sa.Boolean, nullable=False, server_default="true"),
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
    )
    op.create_index("ix_audit_logs_user_id", "audit_logs", ["user_id"])
    op.create_index("ix_audit_logs_action", "audit_logs", ["action"])
    # CREATE POLICY resolves functions at creation time, so
    # current_app_user_id() must exist before the policies below are created.
    # Create it only if an init script has not already provided one.
    # NOTE(review): assumes the backend publishes the caller's id via the
    # `app.current_user_id` GUC — confirm the setting name used by the app.
    op.execute("""
        DO $do$
        BEGIN
            IF to_regproc('current_app_user_id') IS NULL THEN
                CREATE FUNCTION current_app_user_id() RETURNS uuid
                LANGUAGE sql STABLE
                AS $fn$ SELECT NULLIF(current_setting('app.current_user_id', true), '')::uuid $fn$;
            END IF;
        END
        $do$
    """)
    # Enable RLS on user-owned tables. FORCE is essential: without it the
    # table OWNER (typically the single application role) bypasses every
    # policy, making the isolation a no-op in practice.
    for table in ["accounts", "transactions", "budgets", "investment_holdings",
                  "investment_transactions", "net_worth_snapshots"]:
        op.execute(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY")
        op.execute(f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY")
        op.execute(f"""
            CREATE POLICY {table}_user_isolation ON {table}
            USING (user_id = current_app_user_id())
        """)
    # Seed default currencies. Symbols restored for EUR (€) and BTC (₿) —
    # they were empty strings, lost in an earlier encoding mishap.
    op.execute("""
        INSERT INTO currencies (code, name, symbol, is_crypto, decimal_places) VALUES
        ('GBP', 'British Pound', '£', false, 2),
        ('USD', 'US Dollar', '$', false, 2),
        ('EUR', 'Euro', '€', false, 2),
        ('JPY', 'Japanese Yen', '¥', false, 0),
        ('CAD', 'Canadian Dollar', 'CA$', false, 2),
        ('AUD', 'Australian Dollar', 'A$', false, 2),
        ('CHF', 'Swiss Franc', 'Fr', false, 2),
        ('BTC', 'Bitcoin', '₿', true, 8),
        ('ETH', 'Ethereum', 'Ξ', true, 8)
        ON CONFLICT (code) DO NOTHING
    """)
def downgrade() -> None:
    """Tear the initial schema back down: RLS policies first, then the
    tables in reverse dependency order so foreign keys never dangle."""
    rls_tables = (
        "accounts", "transactions", "budgets", "investment_holdings",
        "investment_transactions", "net_worth_snapshots",
    )
    for name in rls_tables:
        op.execute(f"DROP POLICY IF EXISTS {name}_user_isolation ON {name}")
        op.execute(f"ALTER TABLE {name} DISABLE ROW LEVEL SECURITY")
    drop_order = (
        "audit_logs", "net_worth_snapshots", "exchange_rates", "currencies",
        "investment_transactions", "investment_holdings", "asset_prices", "assets",
        "budgets", "transactions", "categories", "accounts", "sessions", "users",
    )
    for name in drop_order:
        op.drop_table(name)

0
backend/app/__init__.py Normal file
View file

View file

14
backend/app/api/router.py Normal file
View file

@ -0,0 +1,14 @@
from fastapi import APIRouter
from app.api.v1 import auth, users, accounts, categories, transactions, budgets, reports, investments, predictions

# Top-level v1 API router: aggregates every feature router for the app.
router = APIRouter()
router.include_router(auth.router, prefix="/auth", tags=["auth"])
router.include_router(users.router, prefix="/users", tags=["users"])
router.include_router(accounts.router, prefix="/accounts", tags=["accounts"])
router.include_router(categories.router, prefix="/categories", tags=["categories"])
router.include_router(transactions.router, prefix="/transactions", tags=["transactions"])
# NOTE(review): the four routers below are mounted without prefix/tags here,
# unlike the five above — presumably they declare their own internally;
# confirm, otherwise their routes land at the API root untagged.
router.include_router(budgets.router)
router.include_router(reports.router)
router.include_router(investments.router)
router.include_router(predictions.router)

View file

View file

@ -0,0 +1,236 @@
import uuid
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.audit import write_audit
from app.dependencies import get_current_user, get_db
from app.schemas.account import AccountCreate, AccountResponse, AccountUpdate
from app.services.account_service import (
AccountError,
create_account,
delete_account,
get_account,
get_net_worth,
list_accounts,
update_account,
)
# Guardrails for the CSV import endpoints below.
MAX_IMPORT_FILE_BYTES = 10 * 1024 * 1024  # 10 MB
MAX_IMPORT_ROWS = 50_000  # hard cap on parsed rows per import
router = APIRouter()
@router.get("", response_model=list[AccountResponse])
async def get_accounts(
    db: AsyncSession = Depends(get_db),
    user=Depends(get_current_user),
):
    """Return every account belonging to the authenticated user."""
    accounts = await list_accounts(db, user.id)
    return accounts
@router.post("", response_model=AccountResponse, status_code=201)
async def create(
    body: AccountCreate,
    db: AsyncSession = Depends(get_db),
    user=Depends(get_current_user),
):
    """Create a new account for the current user, audit it, and commit (201)."""
    result = await create_account(db, user.id, body)
    # Record WHICH account was created — matches the resource_type/resource_id
    # convention the update/delete endpoints in this module already follow.
    # NOTE(review): assumes the service's return value exposes `.id` — confirm
    # against AccountResponse.
    await write_audit(
        db,
        user_id=user.id,
        action="account_create",
        resource_type="account",
        resource_id=result.id,
    )
    await db.commit()
    return result
@router.get("/net-worth")
async def net_worth(
    db: AsyncSession = Depends(get_db),
    user=Depends(get_current_user),
):
    """Return the user's net-worth summary in their base currency."""
    summary = await get_net_worth(db, user.id, user.base_currency)
    return summary
@router.get("/{account_id}", response_model=AccountResponse)
async def get_one(
    account_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user=Depends(get_current_user),
):
    """Fetch a single account owned by the current user; AccountError maps to HTTP."""
    from app.services.account_service import _to_response

    try:
        account = await get_account(db, account_id, user.id)
    except AccountError as e:
        raise HTTPException(status_code=e.status_code, detail=e.detail)
    return _to_response(account)
@router.put("/{account_id}", response_model=AccountResponse)
async def update(
    account_id: uuid.UUID,
    body: AccountUpdate,
    db: AsyncSession = Depends(get_db),
    user=Depends(get_current_user),
):
    """Apply an update to one of the user's accounts, audit it, and commit."""
    try:
        updated = await update_account(db, account_id, user.id, body)
    except AccountError as e:
        raise HTTPException(status_code=e.status_code, detail=e.detail)
    await write_audit(
        db,
        user_id=user.id,
        action="account_update",
        resource_type="account",
        resource_id=account_id,
    )
    await db.commit()
    return updated
@router.post("/{account_id}/import/preview")
async def import_preview(
    account_id: uuid.UUID,
    file: UploadFile = File(...),
    db: AsyncSession = Depends(get_db),
    user=Depends(get_current_user),
):
    """Upload a CSV and get back the detected format, column mapping, and a sample of parsed rows."""
    from app.services.csv_detector import parse_csv_content, detect_format
    # Ownership/existence check before touching the upload.
    try:
        await get_account(db, account_id, user.id)
    except AccountError as e:
        raise HTTPException(status_code=e.status_code, detail=e.detail)
    # Read one byte past the cap so "over limit" is detectable without
    # buffering an arbitrarily large upload.
    content = await file.read(MAX_IMPORT_FILE_BYTES + 1)
    if len(content) > MAX_IMPORT_FILE_BYTES:
        raise HTTPException(status_code=413, detail="File too large (max 10 MB)")
    try:
        headers, rows = parse_csv_content(content)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    if not headers:
        raise HTTPException(status_code=400, detail="Could not read CSV headers")
    mapping = detect_format(headers)
    # Build 5-row preview using the detected mapping
    preview = []
    for row in rows[:5]:
        entry: dict = {
            "date_raw": row.get(mapping.date, ""),
            "description_raw": row.get(mapping.description, ""),
        }
        if mapping.is_split():
            # Separate debit/credit columns: strip thousands separators and
            # the pound sign, treat blanks as zero, net as credit - debit.
            # NOTE(review): float is tolerable here because values are
            # preview-only; the real import path should parse exactly.
            debit_str = row.get(mapping.debit or "", "").replace(",", "").replace("£", "").strip()
            credit_str = row.get(mapping.credit or "", "").replace(",", "").replace("£", "").strip()
            try:
                debit = float(debit_str) if debit_str else 0.0
                credit = float(credit_str) if credit_str else 0.0
                entry["amount_raw"] = credit - debit
            except ValueError:
                # Unparseable number -> surface as None so the UI can flag it.
                entry["amount_raw"] = None
        else:
            # Single signed amount column.
            raw = row.get(mapping.amount or "", "").replace(",", "").replace("£", "").strip()
            try:
                entry["amount_raw"] = float(raw) if raw else None
            except ValueError:
                entry["amount_raw"] = None
        if mapping.balance:
            entry["balance_raw"] = row.get(mapping.balance, "")
        preview.append(entry)
    # Echo the detected mapping so the client can adjust columns before the
    # real import call.
    return {
        "detected_format": mapping.detected_format,
        "headers": headers,
        "mapping": {
            "date": mapping.date,
            "description": mapping.description,
            "amount": mapping.amount,
            "debit": mapping.debit,
            "credit": mapping.credit,
            "balance": mapping.balance,
            "reference": mapping.reference,
        },
        "total_rows": len(rows),
        "preview": preview,
    }
@router.post("/{account_id}/import")
async def import_csv_to_account(
    account_id: uuid.UUID,
    file: UploadFile = File(...),
    date_col: str = Form(...),
    description_col: str = Form(...),
    amount_col: str = Form(default=""),
    debit_col: str = Form(default=""),
    credit_col: str = Form(default=""),
    db: AsyncSession = Depends(get_db),
    user=Depends(get_current_user),
):
    """Import a CSV into an account using an explicit column mapping.

    The caller supplies either a single signed ``amount_col`` or a
    ``debit_col``/``credit_col`` pair (the pair wins when both are given).
    Rows with a missing date or unparseable amount are skipped; 400 when no
    rows survive the mapping. Amounts are parsed with Decimal so that
    ``credit - debit`` never bakes binary-float rounding noise (e.g.
    0.30000000000000004) into the persisted amount strings.
    """
    from decimal import Decimal, InvalidOperation

    from app.services.csv_detector import parse_csv_content
    from app.services.transaction_service import import_csv

    # Ownership/existence check before touching the upload.
    try:
        await get_account(db, account_id, user.id)
    except AccountError as e:
        raise HTTPException(status_code=e.status_code, detail=e.detail)
    # Read one byte past the cap so over-limit uploads are detected without
    # buffering an arbitrarily large file.
    content = await file.read(MAX_IMPORT_FILE_BYTES + 1)
    if len(content) > MAX_IMPORT_FILE_BYTES:
        raise HTTPException(status_code=413, detail="File too large (max 10 MB)")
    try:
        _, rows = parse_csv_content(content)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    if len(rows) > MAX_IMPORT_ROWS:
        raise HTTPException(status_code=400, detail=f"File contains too many rows (max {MAX_IMPORT_ROWS:,})")

    def _to_decimal(raw: str, empty_default):
        # Strip thousands separators and the pound sign; blank cells take the
        # caller's default. Raises InvalidOperation on junk (caller skips row).
        cleaned = raw.replace(",", "").replace("£", "").strip()
        if not cleaned:
            return empty_default
        return Decimal(cleaned)

    use_split = bool(debit_col and credit_col)
    parsed_rows = []
    for row in rows:
        date_val = row.get(date_col, "").strip()
        desc_val = row.get(description_col, "").strip() or "Imported transaction"
        if use_split:
            try:
                debit = _to_decimal(row.get(debit_col, ""), Decimal("0"))
                credit = _to_decimal(row.get(credit_col, ""), Decimal("0"))
            except InvalidOperation:
                continue
            # Sign convention: credits positive, debits negative.
            amount = credit - debit
        else:
            try:
                amount = _to_decimal(row.get(amount_col, ""), None)
            except InvalidOperation:
                continue
        if amount is None:
            continue
        if not date_val:
            continue
        parsed_rows.append({"date": date_val, "description": desc_val, "amount": str(amount)})
    if not parsed_rows:
        raise HTTPException(status_code=400, detail="No valid rows found after applying column mapping")
    result = await import_csv(db, user.id, account_id, parsed_rows, user.base_currency)
    await write_audit(db, user_id=user.id, action="import_data", metadata=result)
    await db.commit()
    return result
@router.delete("/{account_id}", status_code=204)
async def delete(
    account_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user=Depends(get_current_user),
):
    """Delete one of the user's accounts, audit the action, and commit (204)."""
    try:
        await delete_account(db, account_id, user.id)
        await write_audit(
            db,
            user_id=user.id,
            action="account_delete",
            resource_type="account",
            resource_id=account_id,
        )
        await db.commit()
    except AccountError as e:
        raise HTTPException(status_code=e.status_code, detail=e.detail)

342
backend/app/api/v1/auth.py Normal file
View file

@ -0,0 +1,342 @@
"""
Auth endpoints: register, login, TOTP, refresh, logout, sessions.
"""
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.audit import write_audit
from app.core.rate_limiter import is_rate_limited
from app.core.security import create_refresh_token, decode_token, generate_csrf_token, hash_token
from app.dependencies import get_current_user, get_db, get_redis
from app.schemas.auth import (
LoginRequest,
RegisterRequest,
SessionInfo,
TOTPChallengeResponse,
TOTPLoginRequest,
TOTPSetupResponse,
TOTPVerifyRequest,
TokenResponse,
)
from app.services.auth_service import (
AuthError,
authenticate_user,
complete_totp_login,
create_totp_challenge_token,
disable_totp,
enable_totp,
get_sessions,
register_user,
revoke_all_sessions,
revoke_session,
setup_totp,
)
router = APIRouter()
def _ip(request: Request) -> str | None:
    """Best-effort client IP: first X-Forwarded-For hop, else the socket peer.

    NOTE(review): trusts X-Forwarded-For as-is — correct behind a proxy that
    overwrites the header, spoofable if the app is ever exposed directly.
    """
    forwarded = request.headers.get("X-Forwarded-For")
    if not forwarded:
        return request.client.host if request.client else None
    return forwarded.split(",")[0].strip()
def _ua(request: Request) -> str | None:
    """Return the raw User-Agent header, or None when absent."""
    agent = request.headers.get("User-Agent")
    return agent
def _set_refresh_cookie(response: Response, token: str) -> None:
    """Attach the refresh token as a hardened cookie scoped to the auth routes only."""
    one_week = 7 * 24 * 3600
    response.set_cookie(
        "refresh_token",
        token,
        max_age=one_week,
        path="/api/v1/auth",
        httponly=True,
        secure=True,
        samesite="strict",
    )
def _set_csrf_cookie(response: Response, token: str) -> None:
    """Set the CSRF cookie; httponly=False is deliberate so JS can echo it in a header."""
    response.set_cookie(
        "csrf_token",
        token,
        max_age=86400,
        httponly=False,
        secure=True,
        samesite="strict",
    )
@router.post("/register", status_code=201)
async def register(
    body: RegisterRequest,
    request: Request,
    response: Response,
    db: AsyncSession = Depends(get_db),
):
    """Create a user account. No session is issued — the user must log in next."""
    try:
        new_user = await register_user(db, body.email, body.password, body.display_name)
        await write_audit(db, user_id=new_user.id, action="register", ip_address=_ip(request))
        await db.commit()
    except AuthError as e:
        raise HTTPException(status_code=e.status_code, detail=e.detail)
    return {"message": "Account created. Please log in."}
@router.post("/login")
async def login(
    body: LoginRequest,
    request: Request,
    response: Response,
    db: AsyncSession = Depends(get_db),
    redis: Redis = Depends(get_redis),
):
    """Password login step.

    When 2FA is enabled the password check succeeds but no tokens are issued;
    the caller gets a short-lived TOTP challenge instead. Otherwise the
    refresh token lands in a hardened cookie and the access token is returned.
    """
    ip = _ip(request)
    try:
        user, access_token, refresh_token = await authenticate_user(
            db, redis, body.email, body.password, ip, _ua(request)
        )
    except AuthError as e:
        raise HTTPException(status_code=e.status_code, detail=e.detail)
    if access_token is None:
        # Password OK but TOTP enabled — hand back a challenge token.
        challenge = create_totp_challenge_token(user.id)
        await write_audit(db, user_id=user.id, action="login", ip_address=ip, metadata={"totp_required": True})
        await db.commit()
        return TOTPChallengeResponse(challenge_token=challenge)
    _set_refresh_cookie(response, refresh_token)
    _set_csrf_cookie(response, generate_csrf_token())
    await write_audit(db, user_id=user.id, action="login", ip_address=ip)
    await db.commit()
    return TokenResponse(access_token=access_token, expires_in=15 * 60)
@router.post("/login/totp")
async def login_totp(
    body: TOTPLoginRequest,
    request: Request,
    response: Response,
    db: AsyncSession = Depends(get_db),
    redis: Redis = Depends(get_redis),
):
    """Second login step: trade the challenge token + TOTP code for real tokens."""
    ip = _ip(request) or "unknown"
    # Per-IP throttle on code guesses: 10 attempts per rolling minute.
    limited, _ = await is_rate_limited(redis, f"rate:totp:{ip}", limit=10, window_seconds=60)
    if limited:
        raise HTTPException(status_code=429, detail="Too many TOTP attempts — try again shortly")
    try:
        access_token, refresh_token = await complete_totp_login(
            db, body.challenge_token, body.totp_code, _ip(request), _ua(request)
        )
    except AuthError as e:
        raise HTTPException(status_code=e.status_code, detail=e.detail)
    _set_refresh_cookie(response, refresh_token)
    _set_csrf_cookie(response, generate_csrf_token())
    await db.commit()
    return TokenResponse(access_token=access_token, expires_in=15 * 60)
@router.post("/refresh")
async def refresh_token(
    request: Request,
    response: Response,
    db: AsyncSession = Depends(get_db),
):
    """Exchange the refresh-token cookie for a new access/refresh pair.

    Rotates the matched session's stored token hash and re-issues the CSRF
    cookie alongside the new refresh cookie.
    """
    token = request.cookies.get("refresh_token")
    if not token:
        raise HTTPException(status_code=401, detail="No refresh token")
    try:
        payload = decode_token(token, token_type="refresh")
    except Exception:
        # Broad catch is deliberate: every decode failure (expired, malformed,
        # bad signature) collapses to the same opaque 401.
        raise HTTPException(status_code=401, detail="Invalid refresh token")
    import uuid
    from app.core.security import create_access_token
    from sqlalchemy import select
    from datetime import datetime, timezone
    from app.db.models.session import Session
    user_id = uuid.UUID(payload["sub"])
    now = datetime.now(timezone.utc)
    # Find and update session
    # NOTE(review): this picks the FIRST live session for the user and never
    # checks that the presented refresh token corresponds to that (or any)
    # stored session. Consequences: any valid refresh JWT for the user works
    # as long as ONE session is live, and revoking a specific session does not
    # invalidate its refresh token while another remains. Matching on a stored
    # hash of the refresh token looks intended — confirm against how login
    # seeds Session.token_hash.
    result = await db.execute(
        select(Session).where(
            Session.user_id == user_id,
            Session.revoked_at.is_(None),
            Session.expires_at > now,
        )
    )
    session = result.scalars().first()
    if not session:
        raise HTTPException(status_code=401, detail="Session not found")
    new_access = create_access_token(str(user_id))
    new_refresh = create_refresh_token(str(user_id))
    # Rotate session token hash — stores the hash of the new ACCESS token,
    # which logout and the sessions list compare against the Authorization header.
    session.token_hash = hash_token(new_access)
    session.last_active_at = now
    await db.commit()
    csrf = generate_csrf_token()
    _set_refresh_cookie(response, new_refresh)
    _set_csrf_cookie(response, csrf)
    return TokenResponse(access_token=new_access, expires_in=15 * 60)
@router.post("/logout")
async def logout(
    request: Request,
    response: Response,
    db: AsyncSession = Depends(get_db),
    user=Depends(get_current_user),
):
    """Revoke the session backing this request and clear both auth cookies."""
    # Slice off the 7-character "Bearer " prefix to recover the raw token.
    bearer = request.headers.get("Authorization", "")[7:]
    await revoke_session_by_hash(db, hash_token(bearer), user.id)
    await write_audit(db, user_id=user.id, action="logout", ip_address=_ip(request))
    await db.commit()
    response.delete_cookie("refresh_token", path="/api/v1/auth")
    response.delete_cookie("csrf_token")
    return {"message": "Logged out"}
async def revoke_session_by_hash(db, token_hash: str, user_id):
    """Mark the session whose token hash matches as revoked.

    Sets ``revoked_at`` to the current UTC time for the user's session with
    the given hash. The caller is responsible for committing.
    (Removed an unused ``select`` import.)
    """
    from sqlalchemy import update
    from datetime import datetime, timezone
    from app.db.models.session import Session
    await db.execute(
        update(Session)
        .where(Session.user_id == user_id, Session.token_hash == token_hash)
        .values(revoked_at=datetime.now(timezone.utc))
    )
@router.post("/logout-all")
async def logout_all(
request: Request,
response: Response,
db: AsyncSession = Depends(get_db),
user=Depends(get_current_user),
):
token = request.headers.get("Authorization", "")[7:]
await revoke_all_sessions(db, user.id)
await write_audit(db, user_id=user.id, action="logout_all", ip_address=_ip(request))
await db.commit()
response.delete_cookie("refresh_token", path="/api/v1/auth")
response.delete_cookie("csrf_token")
return {"message": "All sessions revoked"}
@router.get("/sessions", response_model=list[SessionInfo])
async def list_sessions(
request: Request,
db: AsyncSession = Depends(get_db),
user=Depends(get_current_user),
):
token = request.headers.get("Authorization", "")[7:]
current_hash = hash_token(token)
sessions = await get_sessions(db, user.id)
result = []
for s in sessions:
info = SessionInfo.model_validate(s)
info.is_current = (s.token_hash == current_hash)
result.append(info)
return result
@router.delete("/sessions/{session_id}", status_code=204)
async def delete_session(
session_id,
request: Request,
db: AsyncSession = Depends(get_db),
user=Depends(get_current_user),
):
import uuid
try:
sid = uuid.UUID(str(session_id))
except ValueError:
raise HTTPException(status_code=422, detail="Invalid session ID")
try:
await revoke_session(db, sid, user.id)
except AuthError as e:
raise HTTPException(status_code=e.status_code, detail=e.detail)
await write_audit(db, user_id=user.id, action="session_revoke", resource_type="session", resource_id=sid)
await db.commit()
@router.get("/totp/setup", response_model=TOTPSetupResponse)
async def totp_setup(
db: AsyncSession = Depends(get_db),
user=Depends(get_current_user),
):
secret, qr_b64, backup_codes = await setup_totp(user, db)
return TOTPSetupResponse(secret=secret, qr_code_png_b64=qr_b64, backup_codes=backup_codes)
@router.post("/totp/verify", status_code=200)
async def totp_verify(
body: TOTPVerifyRequest,
request: Request,
db: AsyncSession = Depends(get_db),
user=Depends(get_current_user),
):
# Secret must be passed back from setup — here we expect it stored temporarily in body
# In practice the client stores it until verification; it's never persisted until verified
# This endpoint receives the secret + verification code
# For simplicity we accept: {"secret": "...", "code": "..."}
# Redefine body inline:
raise HTTPException(status_code=400, detail="Use /totp/enable endpoint with secret and code")
@router.post("/totp/enable", status_code=200)
async def totp_enable(
body: dict,
request: Request,
db: AsyncSession = Depends(get_db),
user=Depends(get_current_user),
):
secret = body.get("secret")
code = body.get("code")
if not secret or not code:
raise HTTPException(status_code=422, detail="secret and code required")
try:
await enable_totp(user, db, secret, code)
except AuthError as e:
raise HTTPException(status_code=e.status_code, detail=e.detail)
await write_audit(db, user_id=user.id, action="totp_enable", ip_address=_ip(request))
await db.commit()
return {"message": "TOTP enabled"}
@router.delete("/totp", status_code=200)
async def totp_disable(
body: dict,
request: Request,
db: AsyncSession = Depends(get_db),
user=Depends(get_current_user),
):
password = body.get("password")
if not password:
raise HTTPException(status_code=422, detail="password required")
try:
await disable_totp(user, db, password)
except AuthError as e:
raise HTTPException(status_code=e.status_code, detail=e.detail)
await write_audit(db, user_id=user.id, action="totp_disable", ip_address=_ip(request))
await db.commit()
return {"message": "TOTP disabled"}

View file

@ -0,0 +1,79 @@
import uuid
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.dependencies import get_current_user, get_db
from app.db.models.user import User
from app.schemas.budget import BudgetCreate, BudgetResponse, BudgetSummaryItem, BudgetUpdate
from app.services import budget_service
router = APIRouter(prefix="/budgets", tags=["budgets"])
@router.get("", response_model=list[BudgetResponse])
async def list_budgets(
active_only: bool = True,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
return await budget_service.list_budgets(db, current_user.id, active_only)
@router.post("", response_model=BudgetResponse, status_code=status.HTTP_201_CREATED)
async def create_budget(
data: BudgetCreate,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
budget = await budget_service.create_budget(db, current_user.id, data)
await db.commit()
return budget
@router.get("/summary", response_model=list[BudgetSummaryItem])
async def get_budget_summary(
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
return await budget_service.get_budget_summary(db, current_user.id)
@router.get("/{budget_id}", response_model=BudgetResponse)
async def get_budget(
budget_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
budget = await budget_service.get_budget(db, current_user.id, budget_id)
if not budget:
raise HTTPException(status_code=404, detail="Budget not found")
return budget
@router.put("/{budget_id}", response_model=BudgetResponse)
async def update_budget(
budget_id: uuid.UUID,
data: BudgetUpdate,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
budget = await budget_service.get_budget(db, current_user.id, budget_id)
if not budget:
raise HTTPException(status_code=404, detail="Budget not found")
budget = await budget_service.update_budget(db, budget, data)
await db.commit()
return budget
@router.delete("/{budget_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_budget(
budget_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
budget = await budget_service.get_budget(db, current_user.id, budget_id)
if not budget:
raise HTTPException(status_code=404, detail="Budget not found")
await budget_service.delete_budget(db, budget)
await db.commit()

View file

@ -0,0 +1,36 @@
import uuid
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from app.dependencies import get_current_user, get_db
from app.services.category_service import create_category, list_categories
router = APIRouter()
@router.get("")
async def get_categories(
db: AsyncSession = Depends(get_db),
user=Depends(get_current_user),
):
return await list_categories(db, user.id)
@router.post("", status_code=201)
async def create(
body: dict,
db: AsyncSession = Depends(get_db),
user=Depends(get_current_user),
):
result = await create_category(
db,
user_id=user.id,
name=body["name"],
type_=body["type"],
icon=body.get("icon"),
color=body.get("color"),
parent_id=uuid.UUID(body["parent_id"]) if body.get("parent_id") else None,
)
await db.commit()
return result

View file

@ -0,0 +1,199 @@
import uuid
from datetime import date
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.dependencies import get_current_user, get_db
from app.db.models.user import User
from app.schemas.investment import (
AssetSearch,
AssetPricePoint,
HoldingCreate,
HoldingResponse,
InvestmentTxnCreate,
InvestmentTxnResponse,
PerformanceMetrics,
PortfolioSummary,
)
from app.services import investment_service
from app.services.price_feed_service import search_yahoo, fetch_history
router = APIRouter(tags=["investments"])
# ── Portfolio ──────────────────────────────────────────────────────────────
@router.get("/investments/portfolio", response_model=PortfolioSummary)
async def get_portfolio(
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
return await investment_service.get_portfolio(db, current_user.id)
@router.get("/investments/performance", response_model=PerformanceMetrics)
async def get_performance(
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
return await investment_service.get_performance(db, current_user.id)
# ── Holdings ───────────────────────────────────────────────────────────────
@router.post("/investments/holdings", response_model=HoldingResponse, status_code=status.HTTP_201_CREATED)
async def create_holding(
data: HoldingCreate,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
from app.db.models.asset import Asset
from sqlalchemy import select
asset_result = await db.execute(select(Asset).where(Asset.id == data.asset_id))
asset = asset_result.scalar_one_or_none()
if not asset:
raise HTTPException(status_code=404, detail="Asset not found")
holding = await investment_service.create_holding(db, current_user.id, data)
await db.commit()
await db.refresh(holding)
return investment_service._holding_to_response(holding, asset)
@router.delete("/investments/holdings/{holding_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_holding(
holding_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
holding = await investment_service.get_holding(db, current_user.id, holding_id)
if not holding:
raise HTTPException(status_code=404, detail="Holding not found")
await db.delete(holding)
await db.commit()
# ── Investment transactions ────────────────────────────────────────────────
@router.get("/investments/holdings/{holding_id}/transactions", response_model=list[InvestmentTxnResponse])
async def list_transactions(
holding_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
return await investment_service.list_investment_transactions(db, current_user.id, holding_id)
@router.post("/investments/transactions", response_model=InvestmentTxnResponse, status_code=status.HTTP_201_CREATED)
async def add_transaction(
data: InvestmentTxnCreate,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
try:
txn = await investment_service.add_investment_transaction(db, current_user.id, data)
await db.commit()
return txn
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
# ── Assets ─────────────────────────────────────────────────────────────────
@router.get("/assets/search", response_model=list[AssetSearch])
async def search_assets(
q: str = Query(..., min_length=1),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
# First search the local DB
local = await investment_service.search_assets(db, q)
if local:
from app.db.models.asset import Asset
return [AssetSearch(
id=a.id, symbol=a.symbol, name=a.name, type=a.type,
currency=a.currency, exchange=a.exchange,
last_price=a.last_price, price_change_24h=a.price_change_24h,
data_source=a.data_source,
) for a in local]
# Fall back to live Yahoo search
import asyncio
loop = asyncio.get_event_loop()
results = await loop.run_in_executor(None, search_yahoo, q)
if not results:
return []
# Upsert into DB so future searches are local
created = []
for r in results:
asset = await investment_service.get_or_create_asset(
db, r["symbol"], r["name"], r["type"],
r["currency"], r["data_source"], r.get("data_source_id"),
r.get("exchange"),
)
created.append(asset)
await db.commit()
return [AssetSearch(
id=a.id, symbol=a.symbol, name=a.name, type=a.type,
currency=a.currency, exchange=a.exchange,
last_price=a.last_price, price_change_24h=a.price_change_24h,
data_source=a.data_source,
) for a in created]
@router.get("/assets/{asset_id}/prices", response_model=list[AssetPricePoint])
async def get_price_history(
asset_id: uuid.UUID,
days: int = Query(default=365, ge=7, le=1825),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
from app.db.models.asset import Asset
from sqlalchemy import select
asset_result = await db.execute(select(Asset).where(Asset.id == asset_id))
asset = asset_result.scalar_one_or_none()
if not asset:
raise HTTPException(status_code=404, detail="Asset not found")
# Fetch from DB; if sparse, refresh from Yahoo
prices = await investment_service.get_price_history(db, asset_id, days)
if len(prices) < 5 and asset.data_source == "yahoo_finance":
rows = await fetch_history(asset.symbol, days)
if rows:
await investment_service.upsert_price_history(db, asset_id, rows)
await db.commit()
prices = await investment_service.get_price_history(db, asset_id, days)
return [
AssetPricePoint(
date=p.date, open=p.open, high=p.high, low=p.low,
close=p.close, volume=p.volume,
)
for p in prices
]
@router.post("/assets", response_model=AssetSearch, status_code=status.HTTP_201_CREATED)
async def create_asset(
symbol: str,
name: str,
asset_type: str = "stock",
currency: str = "GBP",
data_source: str = "yahoo_finance",
data_source_id: str | None = None,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
asset = await investment_service.get_or_create_asset(
db, symbol, name, asset_type, currency, data_source, data_source_id
)
await db.commit()
return AssetSearch(
id=asset.id, symbol=asset.symbol, name=asset.name, type=asset.type,
currency=asset.currency, exchange=asset.exchange,
last_price=asset.last_price, price_change_24h=asset.price_change_24h,
data_source=asset.data_source,
)

View file

@ -0,0 +1,236 @@
from __future__ import annotations
from datetime import date
import calendar
from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel
from redis.asyncio import Redis
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.rate_limiter import is_rate_limited
from app.dependencies import get_current_user, get_db, get_redis
from app.ml.feature_engineering import (
get_monthly_category_spending,
get_monthly_net_worth,
get_current_month_spending,
get_portfolio_monthly_returns,
get_daily_cash_flow,
)
from app.ml.spending_forecast import forecast_spending
from app.ml.net_worth_projection import project_net_worth
from app.ml.monte_carlo import run_monte_carlo
router = APIRouter(prefix="/predictions", tags=["predictions"])
async def _check_prediction_rate(redis: Redis, user_id: str) -> None:
    """Enforce the per-user prediction rate limit (20 calls per 60s).

    Raises HTTP 429 when the limit has been hit; returns None otherwise.
    """
    key = f"rate:pred:{user_id}"
    hit_limit, _ = await is_rate_limited(redis, key, limit=20, window_seconds=60)
    if not hit_limit:
        return
    raise HTTPException(status_code=429, detail="Too many prediction requests — try again shortly")
@router.get("/spending")
async def spending_forecast(
request: Request,
db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
user=Depends(get_current_user),
):
await _check_prediction_rate(redis, str(user.id))
df = await get_monthly_category_spending(db, user.id)
categories = forecast_spending(df)
return {"categories": categories}
@router.get("/net-worth")
async def net_worth_projection(
years: int = 5,
db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
user=Depends(get_current_user),
):
await _check_prediction_rate(redis, str(user.id))
years = max(1, min(10, years))
df = await get_monthly_net_worth(db, user.id)
result = project_net_worth(df, years=years)
return result
class MonteCarloRequest(BaseModel):
    """Request body for the /monte-carlo endpoint."""
    # Projection horizon in years (endpoint clamps to 1-10).
    years: int = 5
    # Number of simulation paths (endpoint clamps to 100-5000).
    n_simulations: int = 1000
    # Extra amount invested per simulated year, in base currency.
    annual_contribution: float = 0.0
@router.post("/monte-carlo")
async def monte_carlo(
body: MonteCarloRequest,
db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
user=Depends(get_current_user),
):
await _check_prediction_rate(redis, str(user.id))
years = max(1, min(10, body.years))
n_sims = max(100, min(5000, body.n_simulations))
# Get portfolio holdings
result = await db.execute(text("""
SELECT h.id, a.symbol, h.quantity::float, a.last_price::float,
(h.quantity * COALESCE(a.last_price, h.avg_cost_basis))::float AS current_value,
h.currency
FROM investment_holdings h
JOIN assets a ON a.id = h.asset_id
WHERE h.user_id = CAST(:uid AS uuid)
AND h.deleted_at IS NULL
AND h.quantity > 0
"""), {"uid": str(user.id)})
holdings = [
{"symbol": r[1], "quantity": r[2], "last_price": r[3], "current_value": r[4]}
for r in result.fetchall()
]
prices_df = await get_portfolio_monthly_returns(db, user.id)
mc = run_monte_carlo(
prices_df=prices_df,
holdings=holdings,
years=years,
n_sims=n_sims,
annual_contribution=body.annual_contribution,
)
return mc
@router.get("/budget-forecast")
async def budget_forecast(
db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
user=Depends(get_current_user),
):
await _check_prediction_rate(redis, str(user.id))
today = date.today()
days_in_month = calendar.monthrange(today.year, today.month)[1]
day_of_month = today.day
days_remaining = days_in_month - day_of_month
# Get budgets
bgt_result = await db.execute(text("""
SELECT b.id::text, COALESCE(c.name, 'Uncategorised') AS cat_name,
b.category_id::text, b.amount::float
FROM budgets b
LEFT JOIN categories c ON c.id = b.category_id
WHERE b.user_id = CAST(:uid AS uuid)
AND b.period = 'monthly'
AND (b.end_date IS NULL OR b.end_date >= CURRENT_DATE)
"""), {"uid": str(user.id)})
budgets = {r[2]: {"budget_id": r[0], "category_name": r[1], "amount": r[3]} for r in bgt_result.fetchall()}
if not budgets:
return {"forecasts": [], "message": "No monthly budgets set"}
# Get current month spending per category
spent_df = await get_current_month_spending(db, user.id)
spent_map = {row["category_id"]: row["spent"] for _, row in spent_df.iterrows()}
forecasts = []
for cat_id, bgt in budgets.items():
spent = spent_map.get(cat_id, 0.0)
budget_amt = bgt["amount"]
# Daily velocity
velocity = spent / day_of_month if day_of_month > 0 else 0.0
forecast_total = spent + velocity * days_remaining
# Probability of overspend using a rough normal distribution
# Assume uncertainty grows with days remaining
import math
sigma = velocity * math.sqrt(days_remaining) * 0.3 if velocity > 0 else 1.0
if sigma > 0:
z = (budget_amt - forecast_total) / sigma
# CDF of normal
import scipy.stats
prob_overspend = float(1 - scipy.stats.norm.cdf(z))
else:
prob_overspend = 1.0 if forecast_total > budget_amt else 0.0
forecasts.append({
"category_id": cat_id,
"category_name": bgt["category_name"],
"budget_amount": round(budget_amt, 2),
"spent_so_far": round(spent, 2),
"forecast_month_total": round(max(spent, forecast_total), 2),
"daily_velocity": round(velocity, 2),
"probability_overspend": round(prob_overspend, 3),
"days_remaining": days_remaining,
})
forecasts.sort(key=lambda x: x["probability_overspend"], reverse=True)
return {"forecasts": forecasts}
@router.get("/cashflow")
async def cashflow_forecast(
db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
user=Depends(get_current_user),
):
await _check_prediction_rate(redis, str(user.id))
from datetime import timedelta
import numpy as np
# Historical daily cash flows (last 90 days)
hist_df = await get_daily_cash_flow(db, user.id, days=90)
# Get current account balances
acct_result = await db.execute(text("""
SELECT SUM(
CASE WHEN type IN ('credit_card','loan','mortgage') THEN -ABS(current_balance)
ELSE current_balance END
)::float AS total_balance
FROM accounts
WHERE user_id = CAST(:uid AS uuid)
AND is_active = TRUE
AND include_in_net_worth = TRUE
AND deleted_at IS NULL
"""), {"uid": str(user.id)})
row = acct_result.fetchone()
current_balance = float(row[0] or 0.0)
# Compute average daily inflow / outflow from history
if not hist_df.empty:
avg_inflow = float(hist_df["inflow"].mean())
avg_outflow = float(hist_df["outflow"].mean())
std_net = float((hist_df["inflow"] - hist_df["outflow"]).std())
else:
avg_inflow = 0.0
avg_outflow = 0.0
std_net = 0.0
# Project 30 days forward
today = date.today()
daily = []
running_balance = current_balance
for i in range(1, 31):
d = today + timedelta(days=i)
net = avg_inflow - avg_outflow
running_balance += net
daily.append({
"date": d.strftime("%Y-%m-%d"),
"balance": round(running_balance, 2),
"avg_inflow": round(avg_inflow, 2),
"avg_outflow": round(avg_outflow, 2),
"negative_risk": running_balance < 0,
})
negative_days = [d["date"] for d in daily if d["negative_risk"]]
return {
"current_balance": round(current_balance, 2),
"avg_daily_inflow": round(avg_inflow, 2),
"avg_daily_outflow": round(avg_outflow, 2),
"forecast": daily,
"negative_risk_days": negative_days,
"history_days": len(hist_df),
}

View file

@ -0,0 +1,82 @@
from datetime import date, timedelta
from fastapi import APIRouter, Depends, Query
from sqlalchemy.ext.asyncio import AsyncSession
from app.dependencies import get_current_user, get_db
from app.db.models.user import User
from app.schemas.report import (
BudgetVsActualReport,
CashFlowReport,
CategoryBreakdownReport,
IncomeExpenseReport,
NetWorthReport,
SpendingTrendsReport,
)
from app.services import report_service
router = APIRouter(prefix="/reports", tags=["reports"])
@router.get("/net-worth", response_model=NetWorthReport)
async def net_worth_report(
months: int = Query(default=12, ge=1, le=60),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
return await report_service.get_net_worth_report(
db, current_user.id, current_user.base_currency, months
)
@router.get("/income-vs-expense", response_model=IncomeExpenseReport)
async def income_expense_report(
months: int = Query(default=12, ge=1, le=60),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
return await report_service.get_income_expense_report(db, current_user.id, months)
@router.get("/cash-flow", response_model=CashFlowReport)
async def cash_flow_report(
date_from: date = Query(default=None),
date_to: date = Query(default=None),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
today = date.today()
df = date_from or (today - timedelta(days=30))
dt = date_to or today
return await report_service.get_cash_flow_report(db, current_user.id, df, dt)
@router.get("/category-breakdown", response_model=CategoryBreakdownReport)
async def category_breakdown(
date_from: date = Query(default=None),
date_to: date = Query(default=None),
type: str = Query(default="expense"),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
today = date.today()
df = date_from or date(today.year, today.month, 1)
dt = date_to or today
return await report_service.get_category_breakdown(db, current_user.id, df, dt, type)
@router.get("/budget-vs-actual", response_model=BudgetVsActualReport)
async def budget_vs_actual(
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
return await report_service.get_budget_vs_actual(db, current_user.id)
@router.get("/spending-trends", response_model=SpendingTrendsReport)
async def spending_trends(
months: int = Query(default=6, ge=1, le=24),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
return await report_service.get_spending_trends(db, current_user.id, months)

View file

@ -0,0 +1,332 @@
import csv
import io
import mimetypes
import os
import uuid
from pathlib import Path
from typing import Annotated
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile
from fastapi.responses import FileResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import get_settings
from app.core.audit import write_audit
from app.dependencies import get_current_user, get_db
from app.schemas.transaction import TransactionCreate, TransactionFilter, TransactionUpdate
from app.services.transaction_service import (
TransactionError,
create_transaction,
delete_transaction,
get_transaction,
import_csv,
list_transactions,
update_transaction,
_to_response,
)
# CSV import limits: cap both raw upload size and parsed row count.
MAX_IMPORT_FILE_BYTES = 10 * 1024 * 1024  # 10 MB
MAX_IMPORT_ROWS = 50_000
# Attachment allow-list: extension is checked first, then the file's
# actual bytes are MIME-sniffed against ALLOWED_MIME_TYPES (the client
# Content-Type header is never trusted).
ALLOWED_MIME_TYPES = {
    "image/jpeg",
    "image/png",
    "image/webp",
    "application/pdf",
}
ALLOWED_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".pdf"}
router = APIRouter()
@router.get("")
async def get_transactions(
account_id: uuid.UUID | None = None,
category_id: uuid.UUID | None = None,
type: str | None = None,
status: str | None = None,
date_from: str | None = None,
date_to: str | None = None,
search: str | None = None,
page: int = Query(default=1, ge=1),
page_size: int = Query(default=50, ge=1, le=200),
db: AsyncSession = Depends(get_db),
user=Depends(get_current_user),
):
from datetime import date
filters = TransactionFilter(
account_id=account_id,
category_id=category_id,
type=type,
status=status,
date_from=date.fromisoformat(date_from) if date_from else None,
date_to=date.fromisoformat(date_to) if date_to else None,
search=search,
page=page,
page_size=page_size,
)
return await list_transactions(db, user.id, filters)
@router.post("", status_code=201)
async def create(
body: TransactionCreate,
db: AsyncSession = Depends(get_db),
user=Depends(get_current_user),
):
try:
result = await create_transaction(db, user.id, body, user.base_currency)
await write_audit(db, user_id=user.id, action="transaction_create")
await db.commit()
return result
except TransactionError as e:
raise HTTPException(status_code=e.status_code, detail=e.detail)
@router.get("/{txn_id}")
async def get_one(
txn_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
user=Depends(get_current_user),
):
try:
txn = await get_transaction(db, txn_id, user.id)
return _to_response(txn)
except TransactionError as e:
raise HTTPException(status_code=e.status_code, detail=e.detail)
@router.put("/{txn_id}")
async def update(
txn_id: uuid.UUID,
body: TransactionUpdate,
db: AsyncSession = Depends(get_db),
user=Depends(get_current_user),
):
try:
result = await update_transaction(db, txn_id, user.id, body, user.base_currency)
await write_audit(db, user_id=user.id, action="transaction_update", resource_type="transaction", resource_id=txn_id)
await db.commit()
return result
except TransactionError as e:
raise HTTPException(status_code=e.status_code, detail=e.detail)
@router.delete("/{txn_id}", status_code=204)
async def delete(
txn_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
user=Depends(get_current_user),
):
try:
await delete_transaction(db, txn_id, user.id)
await write_audit(db, user_id=user.id, action="transaction_delete", resource_type="transaction", resource_id=txn_id)
await db.commit()
except TransactionError as e:
raise HTTPException(status_code=e.status_code, detail=e.detail)
@router.post("/{txn_id}/attachments")
async def upload_attachment(
txn_id: uuid.UUID,
file: UploadFile = File(...),
db: AsyncSession = Depends(get_db),
user=Depends(get_current_user),
):
settings = get_settings()
# Validate extension
filename = file.filename or "upload"
ext = Path(filename).suffix.lower()
if ext not in ALLOWED_EXTENSIONS:
raise HTTPException(status_code=400, detail="Unsupported file type. Allowed: JPG, PNG, WebP, PDF")
# Verify transaction ownership
try:
txn = await get_transaction(db, txn_id, user.id)
except TransactionError as e:
raise HTTPException(status_code=e.status_code, detail=e.detail)
current_refs: list = txn.get("attachment_refs", []) if isinstance(txn, dict) else []
# Fetch raw model for JSONB mutation
from sqlalchemy import select
from app.db.models.transaction import Transaction as TxnModel
result = await db.execute(
select(TxnModel).where(TxnModel.id == txn_id, TxnModel.user_id == user.id)
)
txn_row = result.scalar_one_or_none()
if not txn_row:
raise HTTPException(status_code=404, detail="Transaction not found")
current_refs = list(txn_row.attachment_refs or [])
if len(current_refs) >= settings.max_attachments_per_txn:
raise HTTPException(status_code=400, detail=f"Maximum {settings.max_attachments_per_txn} attachments per transaction")
# Read and size-check
content = await file.read(settings.max_attachment_bytes + 1)
if len(content) > settings.max_attachment_bytes:
raise HTTPException(status_code=413, detail="File too large (max 10 MB)")
# Sniff MIME from content
import magic # python-magic
detected_mime = magic.from_buffer(content[:2048], mime=True)
if detected_mime not in ALLOWED_MIME_TYPES:
raise HTTPException(status_code=400, detail="File content does not match an allowed type (JPEG, PNG, WebP, PDF)")
# Store file
attachment_id = str(uuid.uuid4())
user_upload_dir = Path(settings.upload_dir) / str(user.id)
user_upload_dir.mkdir(parents=True, exist_ok=True)
stored_name = f"{attachment_id}{ext}"
stored_path = user_upload_dir / stored_name
stored_path.write_bytes(content)
# Update attachment_refs
ref = {
"id": attachment_id,
"filename": filename,
"mime_type": detected_mime,
"size": len(content),
"stored_name": stored_name,
}
from sqlalchemy import update as sql_update
import copy
new_refs = copy.copy(current_refs)
new_refs.append(ref)
await db.execute(
sql_update(TxnModel)
.where(TxnModel.id == txn_id)
.values(attachment_refs=new_refs)
)
await write_audit(db, user_id=user.id, action="transaction_update", resource_type="transaction", resource_id=txn_id, metadata={"attachment_added": attachment_id})
await db.commit()
return ref
@router.get("/{txn_id}/attachments/{attachment_id}")
async def download_attachment(
txn_id: uuid.UUID,
attachment_id: str,
db: AsyncSession = Depends(get_db),
user=Depends(get_current_user),
):
settings = get_settings()
from sqlalchemy import select
from app.db.models.transaction import Transaction as TxnModel
result = await db.execute(
select(TxnModel).where(TxnModel.id == txn_id, TxnModel.user_id == user.id)
)
txn_row = result.scalar_one_or_none()
if not txn_row:
raise HTTPException(status_code=404, detail="Transaction not found")
ref = next((r for r in (txn_row.attachment_refs or []) if r["id"] == attachment_id), None)
if not ref:
raise HTTPException(status_code=404, detail="Attachment not found")
path = Path(settings.upload_dir) / str(user.id) / ref["stored_name"]
if not path.exists():
raise HTTPException(status_code=404, detail="Attachment file missing")
return FileResponse(
path=str(path),
media_type=ref["mime_type"],
filename=ref["filename"],
)
@router.delete("/{txn_id}/attachments/{attachment_id}", status_code=204)
async def delete_attachment(
txn_id: uuid.UUID,
attachment_id: str,
db: AsyncSession = Depends(get_db),
user=Depends(get_current_user),
):
settings = get_settings()
from sqlalchemy import select, update as sql_update
from app.db.models.transaction import Transaction as TxnModel
result = await db.execute(
select(TxnModel).where(TxnModel.id == txn_id, TxnModel.user_id == user.id)
)
txn_row = result.scalar_one_or_none()
if not txn_row:
raise HTTPException(status_code=404, detail="Transaction not found")
refs = list(txn_row.attachment_refs or [])
ref = next((r for r in refs if r["id"] == attachment_id), None)
if not ref:
raise HTTPException(status_code=404, detail="Attachment not found")
# Delete file
path = Path(settings.upload_dir) / str(user.id) / ref["stored_name"]
try:
path.unlink(missing_ok=True)
except OSError:
pass
new_refs = [r for r in refs if r["id"] != attachment_id]
await db.execute(
sql_update(TxnModel)
.where(TxnModel.id == txn_id)
.values(attachment_refs=new_refs)
)
await write_audit(db, user_id=user.id, action="transaction_update", resource_type="transaction", resource_id=txn_id, metadata={"attachment_deleted": attachment_id})
await db.commit()
@router.post("/import")
async def import_transactions(
    file: UploadFile = File(...),
    account_id: uuid.UUID = Form(...),
    date_col: str = Form(default="date"),
    description_col: str = Form(default="description"),
    amount_col: str = Form(default="amount"),
    db: AsyncSession = Depends(get_db),
    user=Depends(get_current_user),
):
    """Import transactions from an uploaded CSV into *account_id*.

    Column names are matched case-insensitively against the supplied
    ``*_col`` form fields. Rows missing a date or amount are skipped.
    Raises 400 for non-CSV/over-long/empty files, 413 for oversized files.
    """
    name = file.filename or ""
    if not name.lower().endswith(".csv"):
        raise HTTPException(status_code=400, detail="Only CSV files are supported")
    # Read one byte past the cap so an oversized file is detectable.
    content = await file.read(MAX_IMPORT_FILE_BYTES + 1)
    if len(content) > MAX_IMPORT_FILE_BYTES:
        raise HTTPException(status_code=413, detail="File too large (max 10 MB)")
    try:
        text = content.decode("utf-8-sig")  # utf-8-sig strips a leading BOM
    except UnicodeDecodeError:
        text = content.decode("latin-1")  # latin-1 never fails: last-resort fallback
    parsed = []
    for raw in csv.DictReader(io.StringIO(text)):
        if len(parsed) >= MAX_IMPORT_ROWS:
            raise HTTPException(status_code=400, detail=f"File contains too many rows (max {MAX_IMPORT_ROWS:,})")
        record = {}
        # Flexible column mapping: try the name as given, then lower/upper case.
        for field, source_col in (("date", date_col), ("description", description_col), ("amount", amount_col)):
            cell = raw.get(source_col) or raw.get(source_col.lower()) or raw.get(source_col.upper())
            if cell is not None:
                record[field] = cell.strip()
        # Only rows with both a date and an amount are importable.
        if "date" in record and "amount" in record:
            record.setdefault("description", "Imported transaction")
            parsed.append(record)
    if not parsed:
        raise HTTPException(status_code=400, detail="No valid rows found. Check column names.")
    result = await import_csv(db, user.id, account_id, parsed, user.base_currency)
    await write_audit(db, user_id=user.id, action="import_data", metadata=result)
    await db.commit()
    return result
@router.get("/import/template")
async def import_template():
    """Serve a small sample CSV demonstrating the expected import columns."""
    from fastapi.responses import Response

    sample = "date,description,amount,merchant,notes\n2026-01-15,Tesco Groceries,-45.67,Tesco,\n2026-01-14,Salary,2500.00,Employer,January salary\n"
    return Response(
        content=sample,
        media_type="text/csv",
        headers={"Content-Disposition": "attachment; filename=import_template.csv"},
    )

126
backend/app/api/v1/users.py Normal file
View file

@ -0,0 +1,126 @@
import csv
import io
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.audit import write_audit
from app.core.security import hash_password, verify_password
from app.dependencies import get_current_user, get_db
router = APIRouter()
@router.get("/me")
async def get_me(user=Depends(get_current_user)):
    """Return the authenticated user's profile and preference fields."""
    profile = {
        "id": str(user.id),
        "email": user.email,
        "display_name": user.display_name,
        "base_currency": user.base_currency,
        "theme": user.theme,
        "locale": user.locale,
        "totp_enabled": user.totp_enabled,
        "last_login_at": user.last_login_at,
        "created_at": user.created_at,
    }
    return profile
class PasswordChangeRequest(BaseModel):
    """Request body for POST /me/password."""
    # Existing password, re-verified server-side before any change is applied.
    current_password: str
    # Replacement password; minimum length (10) enforced by validation.
    new_password: str = Field(..., min_length=10)
@router.post("/me/password", status_code=200)
async def change_password(
    body: PasswordChangeRequest,
    db: AsyncSession = Depends(get_db),
    user=Depends(get_current_user),
):
    """Change the current user's password after re-verifying the old one.

    Writes an audit entry; raises 400 when the current password is wrong.
    """
    current_ok = verify_password(body.current_password, user.password_hash)
    if not current_ok:
        raise HTTPException(status_code=400, detail="Current password is incorrect")
    user.password_hash = hash_password(body.new_password)
    user.updated_at = datetime.now(timezone.utc)
    await write_audit(db, user_id=user.id, action="password_change")
    await db.commit()
    return {"message": "Password updated successfully"}
class ProfileUpdateRequest(BaseModel):
    """Request body for PUT /me; omitted (None) fields are left unchanged."""
    display_name: str | None = Field(default=None, max_length=100)
    # ISO-style currency code; upper-cased by the handler before storing.
    base_currency: str | None = Field(default=None, min_length=3, max_length=10)
@router.put("/me", status_code=200)
async def update_profile(
    body: ProfileUpdateRequest,
    db: AsyncSession = Depends(get_db),
    user=Depends(get_current_user),
):
    """Update display name and/or base currency for the current user.

    Only fields present (non-None) in the body are changed; base_currency
    is normalised to upper case. Writes an audit entry for consistency
    with the other mutating user endpoints (password change, data export).
    """
    if body.display_name is not None:
        user.display_name = body.display_name
    if body.base_currency is not None:
        user.base_currency = body.base_currency.upper()
    user.updated_at = datetime.now(timezone.utc)
    # Previously this endpoint was the only user mutation without an audit trail.
    await write_audit(db, user_id=user.id, action="profile_update")
    await db.commit()
    return {"message": "Profile updated"}
@router.get("/me/export")
async def export_data(
    db: AsyncSession = Depends(get_db),
    user=Depends(get_current_user),
):
    """Export all non-deleted transactions for the current user as a CSV download.

    Encrypted columns are decrypted server-side before being written out.
    An audit entry is recorded for the export.
    """
    from app.db.models.transaction import Transaction
    from app.db.models.account import Account
    from app.db.models.category import Category
    from app.core.security import decrypt_field

    result = await db.execute(
        select(Transaction, Account, Category)
        .join(Account, Account.id == Transaction.account_id)
        .outerjoin(Category, Category.id == Transaction.category_id)
        .where(
            Transaction.user_id == user.id,
            Transaction.deleted_at.is_(None),
        )
        .order_by(Transaction.date.desc())
    )
    rows = result.all()
    output = io.StringIO()
    writer = csv.writer(output)
    writer.writerow([
        "date", "description", "merchant", "amount", "currency",
        "type", "status", "category", "account", "notes", "tags",
    ])
    for txn, account, category in rows:
        writer.writerow([
            txn.date.isoformat(),
            decrypt_field(txn.description_enc) or "",
            decrypt_field(txn.merchant_enc) if txn.merchant_enc else "",
            str(txn.amount),
            txn.currency,
            txn.type,
            txn.status,
            category.name if category else "",
            decrypt_field(account.name_enc) or "",
            decrypt_field(txn.notes_enc) if txn.notes_enc else "",
            ",".join(txn.tags or []),
        ])
    output.seek(0)
    filename = f"transactions_{datetime.now(timezone.utc).strftime('%Y%m%d')}.csv"
    await write_audit(db, user_id=user.id, action="data_export")
    await db.commit()
    # BUG FIX: the Content-Disposition header previously contained the literal
    # text "(unknown)" and the computed `filename` was never used. Quote the
    # filename per RFC 6266.
    return StreamingResponse(
        iter([output.getvalue()]),
        media_type="text/csv",
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )

54
backend/app/config.py Normal file
View file

@ -0,0 +1,54 @@
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
    """Application configuration, loaded from environment variables / .env.

    Each field can be overridden by an env var of the same (upper-case)
    name; unrecognised env vars are ignored.
    """

    model_config = SettingsConfigDict(env_file=".env", extra="ignore")
    # NOTE(review): default DSN holds a placeholder password — deployments
    # must override DATABASE_URL.
    database_url: str = "postgresql+asyncpg://finance_app:password@postgres:5432/financedb"
    redis_url: str = "redis://localhost:6379/0"
    encryption_key: str  # 32-byte hex string (64 hex chars); no default — must be set
    backup_passphrase: str = ""
    environment: str = "production"
    allow_registration: bool = False
    base_currency: str = "GBP"
    # JWT — keys read from /run/secrets/ at runtime
    jwt_private_key_file: str = "/run/secrets/jwt_private.pem"
    jwt_public_key_file: str = "/run/secrets/jwt_public.pem"
    jwt_algorithm: str = "RS256"
    access_token_expire_minutes: int = 15
    refresh_token_expire_days: int = 7
    # Security
    csrf_token_expire_hours: int = 24
    max_login_attempts: int = 5
    lockout_base_seconds: int = 1800  # 30 min, doubles each time
    # Rate limits (requests per minute)
    rate_limit_auth: int = 10
    rate_limit_api: int = 300
    rate_limit_predictions: int = 20
    # File uploads
    upload_dir: str = "/app/uploads"
    max_attachment_bytes: int = 10 * 1024 * 1024  # 10 MB
    max_attachments_per_txn: int = 10
    # Background jobs
    price_sync_interval_minutes: int = 15
    fx_sync_interval_minutes: int = 60
    snapshot_hour: int = 2  # 2 AM daily
    backup_hour: int = 3  # 3 AM daily
    ml_retrain_day: str = "sun"  # weekly on Sunday
    @property
    def is_development(self) -> bool:
        """True when ENVIRONMENT=development (e.g. enables SQLAlchemy echo)."""
        return self.environment == "development"
@lru_cache
def get_settings() -> Settings:
    """Return the process-wide Settings instance (constructed once, cached)."""
    return Settings()

View file

38
backend/app/core/audit.py Normal file
View file

@ -0,0 +1,38 @@
"""
Append-only audit log writer.
"""
from __future__ import annotations
import uuid
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
async def write_audit(
    db: "AsyncSession",
    *,
    user_id: uuid.UUID | None,
    action: str,
    resource_type: str | None = None,
    resource_id: uuid.UUID | None = None,
    ip_address: str | None = None,
    user_agent: str | None = None,
    metadata: dict[str, Any] | None = None,
    success: bool = True,
) -> None:
    """Stage one append-only audit row on *db*.

    The row is only added to the session; the caller decides when to
    commit, so audit writes share the caller's transaction.
    """
    from app.db.models.audit_log import AuditLog

    entry = AuditLog(
        user_id=user_id,
        action=action,
        resource_type=resource_type,
        resource_id=resource_id,
        ip_address=ip_address,
        user_agent=user_agent,
        meta=metadata or {},
        success=success,
        created_at=datetime.now(timezone.utc),
    )
    db.add(entry)

View file

@ -0,0 +1,24 @@
"""
Helpers for re-encrypting all sensitive DB fields during key rotation.
"""
from app.core.security import decrypt_field, encrypt_field
def reencrypt(data: bytes, old_key_hex: str, new_key_hex: str) -> bytes:
    """Re-encrypt a bytea field from old key to new key.

    *data* is IV(12) || ciphertext+tag, as produced by encrypt_field.
    Empty/None input is returned unchanged — previously it crashed by
    slicing a zero-length IV (the sibling key-rotation helper already
    guards this case).
    """
    # Guard before the lazy imports so empty fields cost nothing.
    if not data:
        return data
    from cryptography.hazmat.primitives.ciphers.aead import AESGCM
    import os
    old_key = bytes.fromhex(old_key_hex)
    new_key = bytes.fromhex(new_key_hex)
    # Decrypt with old key
    iv = data[:12]
    ciphertext_with_tag = data[12:]
    aesgcm_old = AESGCM(old_key)
    plaintext = aesgcm_old.decrypt(iv, ciphertext_with_tag, None)
    # Encrypt with new key (fresh random IV — never reuse an IV with GCM)
    new_iv = os.urandom(12)
    aesgcm_new = AESGCM(new_key)
    return new_iv + aesgcm_new.encrypt(new_iv, plaintext, None)

View file

@ -0,0 +1,144 @@
"""
AES-256-GCM key rotation: decrypt all encrypted fields with OLD key, re-encrypt with NEW key.
Run while the application is STOPPED:
docker compose exec \
-e ENCRYPTION_KEY="$OLD_ENCRYPTION_KEY" \
-e NEW_ENCRYPTION_KEY="$NEW_ENCRYPTION_KEY" \
backend python -m app.core.key_rotation
On success, update ENCRYPTION_KEY in .env to the new value and restart.
"""
import os
import sys
import logging
from typing import Callable
import psycopg2
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
logging.basicConfig(level=logging.INFO, format="[rotate] %(message)s")
log = logging.getLogger(__name__)
def _make_cipher(key_hex: str) -> AESGCM:
    """Parse a hex-encoded 256-bit key and wrap it in an AESGCM cipher."""
    raw = bytes.fromhex(key_hex)
    if len(raw) != 32:
        raise ValueError("Key must be 32 bytes (64 hex chars)")
    return AESGCM(raw)
def _decrypt(cipher: AESGCM, data: bytes) -> bytes:
    """Return plaintext bytes given IV(12)||ciphertext+tag; b'' stays b''."""
    if not data:
        return b""
    iv, sealed = data[:12], data[12:]
    return cipher.decrypt(iv, sealed, None)
def _encrypt(cipher: AESGCM, plaintext: bytes) -> bytes:
    """Encrypt plaintext bytes → IV(12)||ciphertext+tag; b'' stays b''."""
    if not plaintext:
        return b""
    nonce = os.urandom(12)  # fresh random nonce per encryption
    return nonce + cipher.encrypt(nonce, plaintext, None)
def _reencrypt(old: AESGCM, new: AESGCM, data: bytes | None) -> bytes | None:
    """Decrypt *data* with *old*, re-encrypt with *new*; None/empty passes through."""
    if not data:
        return data
    return _encrypt(new, _decrypt(old, data))
def _reencrypt_hex(old: AESGCM, new: AESGCM, hex_str: str | None) -> str | None:
    """For fields stored as hex strings (e.g. totp_secret_enc).

    None/empty passes through unchanged.
    """
    if not hex_str:
        return hex_str
    rotated = _encrypt(new, _decrypt(old, bytes.fromhex(hex_str)))
    return rotated.hex()
def rotate(db_url: str, old_key_hex: str, new_key_hex: str) -> None:
    """Re-encrypt every encrypted column from the old key to the new key.

    Runs as one all-or-nothing transaction: any failure rolls everything
    back and exits non-zero, leaving the database untouched. Must run while
    the application is stopped so no writer mixes old- and new-key ciphertext.
    """
    old = _make_cipher(old_key_hex)
    new = _make_cipher(new_key_hex)
    conn = psycopg2.connect(db_url)
    conn.autocommit = False  # single transaction; committed only at the end
    cur = conn.cursor()
    try:
        # ------------------------------------------------------------------ accounts
        # These DB columns hold AES-GCM ciphertext; bytes(...) normalises the
        # memoryview psycopg2 returns for bytea values.
        cur.execute("SELECT id, name, institution, notes FROM accounts WHERE deleted_at IS NULL")
        rows = cur.fetchall()
        log.info(f"Rotating {len(rows)} account row(s)…")
        for row_id, name, institution, notes in rows:
            cur.execute(
                "UPDATE accounts SET name=%s, institution=%s, notes=%s WHERE id=%s",
                (
                    _reencrypt(old, new, bytes(name) if name else None),
                    _reencrypt(old, new, bytes(institution) if institution else None),
                    _reencrypt(old, new, bytes(notes) if notes else None),
                    row_id,
                ),
            )
        # -------------------------------------------------------------- transactions
        cur.execute(
            "SELECT id, description, merchant, notes FROM transactions WHERE deleted_at IS NULL"
        )
        rows = cur.fetchall()
        log.info(f"Rotating {len(rows)} transaction row(s)…")
        for row_id, description, merchant, notes in rows:
            cur.execute(
                "UPDATE transactions SET description=%s, merchant=%s, notes=%s WHERE id=%s",
                (
                    _reencrypt(old, new, bytes(description) if description else None),
                    _reencrypt(old, new, bytes(merchant) if merchant else None),
                    _reencrypt(old, new, bytes(notes) if notes else None),
                    row_id,
                ),
            )
        # -------------------------------------------------------------------- users
        # totp_secret is stored as a hex string, not raw bytea.
        cur.execute("SELECT id, totp_secret FROM users WHERE deleted_at IS NULL")
        rows = cur.fetchall()
        log.info(f"Rotating {len(rows)} user row(s)…")
        for row_id, totp_secret in rows:
            cur.execute(
                "UPDATE users SET totp_secret=%s WHERE id=%s",
                (_reencrypt_hex(old, new, totp_secret), row_id),
            )
        conn.commit()
        log.info("Key rotation complete — all fields re-encrypted.")
        log.info("Now update ENCRYPTION_KEY in .env and restart the application.")
    except Exception:
        conn.rollback()
        log.exception("Rotation FAILED — rolled back, no data changed.")
        sys.exit(1)
    finally:
        cur.close()
        conn.close()
if __name__ == "__main__":
    # The current key arrives via the app's regular ENCRYPTION_KEY variable;
    # the replacement is passed separately as NEW_ENCRYPTION_KEY.
    old_key = os.environ.get("ENCRYPTION_KEY", "")
    new_key = os.environ.get("NEW_ENCRYPTION_KEY", "")
    # psycopg2 is synchronous: strip the asyncpg driver suffix from the app's URL.
    db_url = os.environ.get("DATABASE_URL", "").replace("postgresql+asyncpg://", "postgresql://")
    if not old_key:
        log.error("ENCRYPTION_KEY (current/old key) is not set")
        sys.exit(1)
    if not new_key:
        log.error("NEW_ENCRYPTION_KEY is not set")
        sys.exit(1)
    if not db_url:
        log.error("DATABASE_URL is not set")
        sys.exit(1)
    # Running with identical keys would pointlessly rewrite every row.
    if old_key == new_key:
        log.error("NEW_ENCRYPTION_KEY is the same as ENCRYPTION_KEY — nothing to do")
        sys.exit(1)
    rotate(db_url, old_key, new_key)

View file

@ -0,0 +1,81 @@
"""
Security middleware: headers, CSRF double-submit, request ID, RLS user context.
"""
import hmac
import uuid

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
# HTTP methods treated as non-mutating — exempt from CSRF checks.
SAFE_METHODS = {"GET", "HEAD", "OPTIONS"}
# Hardened headers stamped onto every response by SecurityHeadersMiddleware.
SECURITY_HEADERS = {
    "X-Frame-Options": "DENY",
    "X-Content-Type-Options": "nosniff",
    "Referrer-Policy": "strict-origin-when-cross-origin",
    "Permissions-Policy": "camera=(), microphone=(), geolocation=()",
    "Cross-Origin-Opener-Policy": "same-origin",
    "Cross-Origin-Resource-Policy": "same-origin",
    "Strict-Transport-Security": "max-age=63072000; includeSubDomains",
    # 'unsafe-inline' styles allowed; scripts restricted to same origin.
    "Content-Security-Policy": (
        "default-src 'self'; "
        "script-src 'self'; "
        "style-src 'self' 'unsafe-inline'; "
        "img-src 'self' data:; "
        "connect-src 'self'; "
        "form-action 'self'; "
        "frame-ancestors 'none'"
    ),
}
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Attach the hardened SECURITY_HEADERS plus a fresh X-Request-ID to
    every response."""

    async def dispatch(self, request: Request, call_next):
        response: Response = await call_next(request)
        for name, value in SECURITY_HEADERS.items():
            response.headers[name] = value
        # New random correlation id per response.
        response.headers["X-Request-ID"] = str(uuid.uuid4())
        return response
class CSRFMiddleware(BaseHTTPMiddleware):
    """Double-submit cookie CSRF protection for mutating requests.

    Safe methods and a small set of auth endpoints are exempt; any other
    mutating request must echo the ``csrf_token`` cookie back in the
    ``X-CSRF-Token`` header.
    """

    EXEMPT_PATHS = {"/api/v1/auth/login", "/api/v1/auth/refresh", "/api/v1/auth/register", "/health"}

    async def dispatch(self, request: Request, call_next):
        # Always set the csrf_token cookie if it doesn't exist yet
        existing_csrf = request.cookies.get("csrf_token")
        if request.method in SAFE_METHODS:
            response: Response = await call_next(request)
            if not existing_csrf:
                self._issue_cookie(response)
            return response
        if request.url.path in self.EXEMPT_PATHS:
            response = await call_next(request)
            if not existing_csrf:
                self._issue_cookie(response)
            return response
        # TOTP login step happens before the client has a token.
        # (/api/v1/auth/login is already caught by EXEMPT_PATHS above.)
        if request.url.path in {"/api/v1/auth/login", "/api/v1/auth/login/totp"}:
            return await call_next(request)
        header_token = request.headers.get("X-CSRF-Token")
        # SECURITY FIX: use a constant-time comparison so token equality
        # cannot be probed via timing; encode to bytes because
        # hmac.compare_digest rejects non-ASCII str input.
        if (
            not existing_csrf
            or not header_token
            or not hmac.compare_digest(existing_csrf.encode(), header_token.encode())
        ):
            return JSONResponse(
                status_code=403,
                content={"detail": "CSRF token missing or invalid"},
            )
        return await call_next(request)

    @staticmethod
    def _issue_cookie(response: Response) -> None:
        """Set a fresh, JS-readable csrf_token cookie on *response*."""
        response.set_cookie(
            "csrf_token",
            str(uuid.uuid4()),
            httponly=False,  # must be readable by JS
            samesite="lax",
            secure=False,  # set True if TLS is terminated at this service
        )

View file

@ -0,0 +1,28 @@
"""
Redis sliding window rate limiter.
"""
import time
from redis.asyncio import Redis
async def is_rate_limited(
    redis: Redis,
    key: str,
    limit: int,
    window_seconds: int = 60,
) -> tuple[bool, int]:
    """
    Returns (is_limited, requests_remaining).

    Sliding-window limiter backed by a Redis sorted set: members are
    per-request markers scored by their timestamp; entries older than the
    window are pruned on every call.

    BUG FIX: the member used to be ``str(now)`` alone, so two requests
    landing on the same clock reading overwrote each other in the set and
    traffic was undercounted. A random suffix makes every request distinct.
    """
    import uuid

    now = time.time()
    window_start = now - window_seconds
    member = f"{now}:{uuid.uuid4().hex}"
    pipe = redis.pipeline()
    pipe.zremrangebyscore(key, 0, window_start)  # drop entries outside the window
    pipe.zadd(key, {member: now})                # record this request
    pipe.zcard(key)                              # count requests in the window
    pipe.expire(key, window_seconds + 1)         # garbage-collect idle keys
    results = await pipe.execute()
    count = results[2]
    remaining = max(0, limit - count)
    return count > limit, remaining

View file

@ -0,0 +1,197 @@
"""
Cryptographic primitives: Argon2id password hashing, RS256 JWT, AES-256-GCM field encryption, TOTP.
"""
import base64
import os
import secrets
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any
import pyotp
import qrcode
import qrcode.image.svg
from argon2 import PasswordHasher
from argon2.exceptions import VerifyMismatchError, VerificationError, InvalidHashError
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from jose import JWTError, jwt
from app.config import get_settings
# Argon2id — OWASP recommended parameters:
# 3 passes over 64 MiB with 4 lanes; 32-byte digest, 16-byte salt.
_ph = PasswordHasher(
    time_cost=3,
    memory_cost=65536,
    parallelism=4,
    hash_len=32,
    salt_len=16,
)
# ---------------------------------------------------------------------------
# Password hashing
# ---------------------------------------------------------------------------
def hash_password(password: str) -> str:
    """Hash *password* with the module-level Argon2id hasher."""
    return _ph.hash(password)
def verify_password(password: str, hashed: str) -> bool:
    """Return True iff *password* matches *hashed*.

    Any verification failure — mismatch, corrupt or unrecognised hash —
    is reported as False rather than raised.
    """
    try:
        _ph.verify(hashed, password)
    except (VerifyMismatchError, VerificationError, InvalidHashError):
        return False
    return True
def password_needs_rehash(hashed: str) -> bool:
    """True when *hashed* was produced with outdated Argon2 parameters."""
    return _ph.check_needs_rehash(hashed)
# ---------------------------------------------------------------------------
# JWT (RS256)
# ---------------------------------------------------------------------------
def _load_private_key() -> str:
    """Read the RS256 signing key (PEM) from the configured secrets file."""
    settings = get_settings()
    return Path(settings.jwt_private_key_file).read_text()
def _load_public_key() -> str:
    """Read the RS256 verification key (PEM) from the configured secrets file."""
    settings = get_settings()
    return Path(settings.jwt_public_key_file).read_text()
def create_access_token(subject: str, extra: dict[str, Any] | None = None) -> str:
    """Mint a short-lived RS256 access token for *subject*.

    Extra claims, if given, are merged into the payload (and may override
    the standard ones).
    """
    settings = get_settings()
    issued = datetime.now(timezone.utc)
    claims: dict[str, Any] = {
        "sub": subject,
        "iat": issued,
        "exp": issued + timedelta(minutes=settings.access_token_expire_minutes),
        "type": "access",
    }
    claims.update(extra or {})
    return jwt.encode(claims, _load_private_key(), algorithm=settings.jwt_algorithm)
def create_refresh_token(subject: str) -> str:
    """Mint a long-lived RS256 refresh token for *subject*.

    Carries a random ``jti`` so individual refresh tokens can be tracked.
    """
    settings = get_settings()
    issued = datetime.now(timezone.utc)
    claims: dict[str, Any] = {
        "sub": subject,
        "iat": issued,
        "exp": issued + timedelta(days=settings.refresh_token_expire_days),
        "type": "refresh",
        "jti": secrets.token_hex(16),
    }
    return jwt.encode(claims, _load_private_key(), algorithm=settings.jwt_algorithm)
def decode_token(token: str, token_type: str = "access") -> dict[str, Any]:
    """Verify *token*'s signature and expiry, and that its ``type`` claim
    matches *token_type* (prevents refresh tokens being used as access
    tokens and vice versa). Raises JWTError on any failure.
    """
    settings = get_settings()
    claims = jwt.decode(
        token,
        _load_public_key(),
        algorithms=[settings.jwt_algorithm],
        options={"verify_exp": True},
    )
    if claims.get("type") != token_type:
        raise JWTError("Invalid token type")
    return claims
# ---------------------------------------------------------------------------
# AES-256-GCM field encryption
# ---------------------------------------------------------------------------
def _get_aes_key() -> bytes:
    """Load and validate the 32-byte AES key from the hex ENCRYPTION_KEY setting."""
    raw = bytes.fromhex(get_settings().encryption_key)
    if len(raw) != 32:
        raise ValueError("ENCRYPTION_KEY must be a 32-byte hex string (64 hex chars)")
    return raw
def encrypt_field(plaintext: str) -> bytes:
    """Encrypt a string field. Returns IV(12) || ciphertext || tag(16) as bytes.

    The empty string maps to b"" without touching the key.
    """
    if not plaintext:
        return b""
    nonce = os.urandom(12)  # fresh nonce per encryption — never reused
    sealed = AESGCM(_get_aes_key()).encrypt(nonce, plaintext.encode(), None)
    return nonce + sealed
def decrypt_field(data: bytes) -> str:
    """Decrypt bytes produced by encrypt_field back to a string.

    Empty input maps to "" without touching the key.
    """
    if not data:
        return ""
    nonce, sealed = data[:12], data[12:]
    return AESGCM(_get_aes_key()).decrypt(nonce, sealed, None).decode()
def encrypt_field_b64(plaintext: str) -> str:
    """Convenience wrapper: encrypt_field, then base64-encode for JSON/text contexts."""
    return base64.b64encode(encrypt_field(plaintext)).decode()
def decrypt_field_b64(data: str) -> str:
    """Inverse of encrypt_field_b64: base64-decode, then decrypt to str."""
    return decrypt_field(base64.b64decode(data))
# ---------------------------------------------------------------------------
# TOTP (RFC 6238)
# ---------------------------------------------------------------------------
def generate_totp_secret() -> str:
    """Return a fresh random base32 TOTP secret."""
    return pyotp.random_base32()
def get_totp_uri(secret: str, email: str) -> str:
    """Build the otpauth:// provisioning URI consumed by authenticator apps."""
    return pyotp.totp.TOTP(secret).provisioning_uri(
        name=email, issuer_name="Finance Tracker"
    )
def generate_totp_qr_png(secret: str, email: str) -> bytes:
    """Render the TOTP provisioning URI for *secret*/*email* as PNG bytes."""
    from io import BytesIO

    buffer = BytesIO()
    image = qrcode.make(get_totp_uri(secret, email))
    image.save(buffer, format="PNG")
    return buffer.getvalue()
def verify_totp(secret: str, code: str) -> bool:
    """Check *code* against *secret*, tolerating one time step of clock drift."""
    return pyotp.TOTP(secret).verify(code, valid_window=1)
# ---------------------------------------------------------------------------
# CSRF token
# ---------------------------------------------------------------------------
def generate_csrf_token() -> str:
    """Random 256-bit CSRF token, hex-encoded (64 chars)."""
    token = secrets.token_hex(32)
    return token
# ---------------------------------------------------------------------------
# Misc helpers
# ---------------------------------------------------------------------------
def generate_backup_codes(count: int = 8) -> list[str]:
    """Generate one-time backup codes of the form XXXXXXXX-XXXXXXXX (upper hex)."""
    codes = []
    for _ in range(count):
        left = secrets.token_hex(4).upper()
        right = secrets.token_hex(4).upper()
        codes.append(f"{left}-{right}")
    return codes
def hash_token(token: str) -> str:
    """SHA-256 hash (hex digest) of a bearer token for DB storage."""
    import hashlib

    digest = hashlib.sha256(token.encode())
    return digest.hexdigest()

View file

28
backend/app/db/base.py Normal file
View file

@ -0,0 +1,28 @@
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
from app.config import get_settings
class Base(DeclarativeBase):
    """Declarative base shared by all ORM models."""
    pass
def create_engine():
    """Build the async engine from settings, with pooling tuned for a
    small self-hosted deployment."""
    cfg = get_settings()
    return create_async_engine(
        cfg.database_url,
        pool_size=10,
        max_overflow=20,
        pool_pre_ping=True,  # detect dropped connections before handing them out
        echo=cfg.is_development,  # SQL echo only in development
    )
def create_session_factory(engine):
    """Session factory bound to *engine*.

    expire_on_commit=False keeps ORM objects usable after commit;
    autoflush is disabled so flushes are explicit.
    """
    factory = async_sessionmaker(
        engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
    )
    return factory

View file

@ -0,0 +1,19 @@
from app.db.models.user import User
from app.db.models.session import Session
from app.db.models.account import Account
from app.db.models.category import Category
from app.db.models.transaction import Transaction
from app.db.models.budget import Budget
from app.db.models.asset import Asset
from app.db.models.asset_price import AssetPrice
from app.db.models.investment_holding import InvestmentHolding
from app.db.models.investment_transaction import InvestmentTransaction
from app.db.models.currency import Currency, ExchangeRate
from app.db.models.net_worth_snapshot import NetWorthSnapshot
from app.db.models.audit_log import AuditLog
__all__ = [
"User", "Session", "Account", "Category", "Transaction", "Budget",
"Asset", "AssetPrice", "InvestmentHolding", "InvestmentTransaction",
"Currency", "ExchangeRate", "NetWorthSnapshot", "AuditLog",
]

View file

@ -0,0 +1,36 @@
import uuid
from datetime import datetime
from decimal import Decimal
from sqlalchemy import Boolean, DateTime, ForeignKey, LargeBinary, Numeric, String, Text
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.db.base import Base
class Account(Base):
    """A user-owned financial account (bank, card, investment wrapper, …).

    Attributes suffixed ``_enc`` hold AES-GCM ciphertext (LargeBinary) and
    map onto plainly-named DB columns (e.g. ``name_enc`` → column ``name``).
    """

    __tablename__ = "accounts"
    id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    user_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
    name_enc: Mapped[bytes] = mapped_column("name", LargeBinary, nullable=False)
    institution_enc: Mapped[bytes | None] = mapped_column("institution", LargeBinary, nullable=True)
    type: Mapped[str] = mapped_column(String(30), nullable=False)
    currency: Mapped[str] = mapped_column(String(10), nullable=False)
    # Numeric(20, 8): high precision to cover crypto-style balances.
    current_balance: Mapped[Decimal] = mapped_column(Numeric(20, 8), default=0, nullable=False)
    credit_limit: Mapped[Decimal | None] = mapped_column(Numeric(20, 8), nullable=True)
    interest_rate: Mapped[Decimal | None] = mapped_column(Numeric(8, 4), nullable=True)
    is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
    include_in_net_worth: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
    color: Mapped[str] = mapped_column(String(7), default="#6366f1", nullable=False)  # "#rrggbb"
    icon: Mapped[str | None] = mapped_column(Text, nullable=True)
    notes_enc: Mapped[bytes | None] = mapped_column("notes", LargeBinary, nullable=True)
    meta: Mapped[dict] = mapped_column(JSONB, default=dict, nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    # Soft delete marker: non-NULL rows are filtered out by queries.
    deleted_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    # lazy="noload": relationships are never implicitly loaded — query explicitly.
    user: Mapped["User"] = relationship(back_populates="accounts", lazy="noload")  # type: ignore[name-defined]
    transactions: Mapped[list["Transaction"]] = relationship(foreign_keys="Transaction.account_id", back_populates="account", lazy="noload")  # type: ignore[name-defined]
    holdings: Mapped[list["InvestmentHolding"]] = relationship(back_populates="account", lazy="noload")  # type: ignore[name-defined]

View file

@ -0,0 +1,32 @@
import uuid
from datetime import datetime
from decimal import Decimal
from sqlalchemy import Boolean, DateTime, Numeric, String, Text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.db.base import Base
class Asset(Base):
    """A tradeable instrument shared across users (no user_id column),
    with a cached latest price refreshed from the configured data source."""

    __tablename__ = "assets"
    id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    symbol: Mapped[str] = mapped_column(Text, nullable=False, index=True)
    name: Mapped[str] = mapped_column(Text, nullable=False)
    type: Mapped[str] = mapped_column(String(30), nullable=False)  # stock|etf|mutual_fund|bond|crypto|commodity|other
    currency: Mapped[str] = mapped_column(String(10), nullable=False)
    exchange: Mapped[str | None] = mapped_column(Text, nullable=True)
    isin: Mapped[str | None] = mapped_column(String(12), nullable=True)  # ISINs are exactly 12 chars
    data_source: Mapped[str] = mapped_column(String(30), default="yahoo_finance", nullable=False)
    # Identifier used by the external data source, when it differs from `symbol`.
    data_source_id: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Cached latest quote (denormalised from asset_prices for fast reads).
    last_price: Mapped[Decimal | None] = mapped_column(Numeric(20, 8), nullable=True)
    last_price_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    price_change_24h: Mapped[Decimal | None] = mapped_column(Numeric(10, 4), nullable=True)
    is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    prices: Mapped[list["AssetPrice"]] = relationship(back_populates="asset", lazy="noload")  # type: ignore[name-defined]
    holdings: Mapped[list["InvestmentHolding"]] = relationship(back_populates="asset", lazy="noload")  # type: ignore[name-defined]

View file

@ -0,0 +1,25 @@
import uuid
from datetime import date, datetime
from decimal import Decimal
from sqlalchemy import Date, DateTime, ForeignKey, Numeric
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.db.base import Base
class AssetPrice(Base):
    """Historical OHLCV price record for an asset on a given date.

    Only ``close`` is mandatory; open/high/low/volume may be absent
    depending on what the data source provides.
    """

    __tablename__ = "asset_prices"
    id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    asset_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False, index=True)
    date: Mapped[date] = mapped_column(Date, nullable=False, index=True)
    open: Mapped[Decimal | None] = mapped_column(Numeric(20, 8), nullable=True)
    high: Mapped[Decimal | None] = mapped_column(Numeric(20, 8), nullable=True)
    low: Mapped[Decimal | None] = mapped_column(Numeric(20, 8), nullable=True)
    close: Mapped[Decimal] = mapped_column(Numeric(20, 8), nullable=False)
    volume: Mapped[Decimal | None] = mapped_column(Numeric(30, 8), nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    asset: Mapped["Asset"] = relationship(back_populates="prices", lazy="noload")  # type: ignore[name-defined]

View file

@ -0,0 +1,23 @@
import uuid
from datetime import datetime
from sqlalchemy import Boolean, DateTime, ForeignKey, String, Text
from sqlalchemy.dialects.postgresql import INET, JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column
from app.db.base import Base
class AuditLog(Base):
    """Append-only audit trail row; written via app.core.audit.write_audit.

    user_id uses ON DELETE SET NULL so log rows outlive deleted users.
    """

    __tablename__ = "audit_logs"
    id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    user_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True, index=True)
    action: Mapped[str] = mapped_column(String(50), nullable=False, index=True)
    resource_type: Mapped[str | None] = mapped_column(Text, nullable=True)
    resource_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), nullable=True)
    ip_address: Mapped[str | None] = mapped_column(INET, nullable=True)
    user_agent: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Attribute named `meta` because `metadata` is reserved by SQLAlchemy's
    # Declarative API; the DB column is still called "metadata".
    meta: Mapped[dict] = mapped_column("metadata", JSONB, default=dict, nullable=False)
    success: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)

View file

@ -0,0 +1,30 @@
import uuid
from datetime import date, datetime
from decimal import Decimal
from sqlalchemy import Boolean, Date, DateTime, ForeignKey, Numeric, String, Text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.db.base import Base
class Budget(Base):
    """A per-category spending budget for a user over a recurring period."""

    __tablename__ = "budgets"
    id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    user_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
    category_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("categories.id"), nullable=False)
    name: Mapped[str] = mapped_column(Text, nullable=False)
    amount: Mapped[Decimal] = mapped_column(Numeric(20, 8), nullable=False)
    currency: Mapped[str] = mapped_column(String(10), nullable=False)
    period: Mapped[str] = mapped_column(String(20), nullable=False)  # weekly|monthly|quarterly|yearly
    start_date: Mapped[date] = mapped_column(Date, nullable=False)
    # NULL end_date: the budget recurs indefinitely.
    end_date: Mapped[date | None] = mapped_column(Date, nullable=True)
    rollover: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
    # Percentage of the budget (0–100) at which an alert fires; default 80%.
    alert_threshold: Mapped[Decimal] = mapped_column(Numeric(5, 2), default=80, nullable=False)
    is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False, index=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    category: Mapped["Category"] = relationship(back_populates="budgets", lazy="noload")  # type: ignore[name-defined]

View file

@ -0,0 +1,26 @@
import uuid
from datetime import datetime
from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String, Text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.db.base import Base
class Category(Base):
    """A transaction category; supports one level of nesting via parent_id."""
    __tablename__ = "categories"
    id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # NULL user_id — presumably a shared/system category (see is_system); confirm in seeding logic.
    user_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=True, index=True)
    name: Mapped[str] = mapped_column(Text, nullable=False)
    # Self-referential FK for sub-categories.
    parent_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), ForeignKey("categories.id"), nullable=True)
    type: Mapped[str] = mapped_column(String(20), nullable=False)  # income | expense | transfer
    icon: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Hex colour like "#rrggbb" (String(7)).
    color: Mapped[str | None] = mapped_column(String(7), nullable=True)
    is_system: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
    sort_order: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    children: Mapped[list["Category"]] = relationship(lazy="noload")
    budgets: Mapped[list["Budget"]] = relationship(back_populates="category", lazy="noload")  # type: ignore[name-defined]

View file

@ -0,0 +1,31 @@
import uuid
from datetime import datetime
from decimal import Decimal
from sqlalchemy import Boolean, DateTime, Integer, Numeric, String, Text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column
from app.db.base import Base
class Currency(Base):
    """Reference table of supported fiat/crypto currencies, keyed by code."""
    __tablename__ = "currencies"
    # ISO-style currency code (e.g. "GBP"); primary key.
    code: Mapped[str] = mapped_column(String(10), primary_key=True)
    name: Mapped[str] = mapped_column(Text, nullable=False)
    symbol: Mapped[str] = mapped_column(String(5), nullable=False)
    is_crypto: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
    # Display precision (2 for most fiat).
    decimal_places: Mapped[int] = mapped_column(Integer, default=2, nullable=False)
    is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
class ExchangeRate(Base):
    """A point-in-time FX rate: 1 unit of base_currency = `rate` quote_currency."""
    __tablename__ = "exchange_rates"
    id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    base_currency: Mapped[str] = mapped_column(String(10), nullable=False, index=True)
    quote_currency: Mapped[str] = mapped_column(String(10), nullable=False, index=True)
    rate: Mapped[Decimal] = mapped_column(Numeric(20, 10), nullable=False)
    # Identifier of the upstream rate provider.
    source: Mapped[str] = mapped_column(String(50), nullable=False)
    fetched_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True)

View file

@ -0,0 +1,27 @@
import uuid
from datetime import datetime
from decimal import Decimal
from sqlalchemy import DateTime, ForeignKey, Numeric, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.db.base import Base
class InvestmentHolding(Base):
    """Current position in one asset within one investment account."""
    __tablename__ = "investment_holdings"
    id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    user_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
    account_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("accounts.id"), nullable=False)
    asset_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("assets.id"), nullable=False)
    # Units held; high precision (30,10) to cover fractional shares/crypto.
    quantity: Mapped[Decimal] = mapped_column(Numeric(30, 10), default=0, nullable=False)
    # Average acquisition cost per unit, in `currency` — TODO confirm per-unit vs total in the service layer.
    avg_cost_basis: Mapped[Decimal] = mapped_column(Numeric(20, 8), default=0, nullable=False)
    currency: Mapped[str] = mapped_column(String(10), nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    account: Mapped["Account"] = relationship(back_populates="holdings", lazy="noload")  # type: ignore[name-defined]
    asset: Mapped["Asset"] = relationship(back_populates="holdings", lazy="noload")  # type: ignore[name-defined]
    investment_transactions: Mapped[list["InvestmentTransaction"]] = relationship(back_populates="holding", lazy="noload")  # type: ignore[name-defined]

View file

@ -0,0 +1,29 @@
import uuid
from datetime import date, datetime
from decimal import Decimal
from sqlalchemy import Date, DateTime, ForeignKey, LargeBinary, Numeric, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.db.base import Base
class InvestmentTransaction(Base):
    """A single event (trade, dividend, fee, …) applied to an investment holding."""
    __tablename__ = "investment_transactions"
    id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    user_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
    holding_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("investment_holdings.id"), nullable=False, index=True)
    # Optional link to the cash-side transactions row.
    transaction_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), ForeignKey("transactions.id"), nullable=True)
    type: Mapped[str] = mapped_column(String(20), nullable=False)  # buy|sell|dividend|split|merger|transfer_in|transfer_out|fee
    quantity: Mapped[Decimal] = mapped_column(Numeric(30, 10), nullable=False)
    # Per-unit price in `currency`.
    price: Mapped[Decimal] = mapped_column(Numeric(20, 8), nullable=False)
    fees: Mapped[Decimal] = mapped_column(Numeric(20, 8), default=0, nullable=False)
    total_amount: Mapped[Decimal] = mapped_column(Numeric(20, 8), nullable=False)
    currency: Mapped[str] = mapped_column(String(10), nullable=False)
    date: Mapped[date] = mapped_column(Date, nullable=False, index=True)
    # Ciphertext stored in DB column "notes" — application-layer encrypted per the *_enc naming convention.
    notes_enc: Mapped[bytes | None] = mapped_column("notes", LargeBinary, nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    holding: Mapped["InvestmentHolding"] = relationship(back_populates="investment_transactions", lazy="noload")  # type: ignore[name-defined]

View file

@ -0,0 +1,23 @@
import uuid
from datetime import date, datetime
from decimal import Decimal
from sqlalchemy import Date, DateTime, ForeignKey, Numeric, String
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column
from app.db.base import Base
class NetWorthSnapshot(Base):
    """Daily snapshot of a user's total net worth in their base currency."""
    __tablename__ = "net_worth_snapshots"
    id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    user_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
    date: Mapped[date] = mapped_column(Date, nullable=False)
    total_assets: Mapped[Decimal] = mapped_column(Numeric(20, 8), nullable=False)
    total_liabilities: Mapped[Decimal] = mapped_column(Numeric(20, 8), nullable=False)
    # net_worth = total_assets - total_liabilities — TODO confirm invariant is enforced by the snapshot job.
    net_worth: Mapped[Decimal] = mapped_column(Numeric(20, 8), nullable=False)
    base_currency: Mapped[str] = mapped_column(String(10), nullable=False)
    # Free-form JSONB — presumably per-account/asset-class amounts; verify against the snapshot writer.
    breakdown: Mapped[dict] = mapped_column(JSONB, default=dict, nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)

View file

@ -0,0 +1,24 @@
import uuid
from datetime import datetime
from sqlalchemy import Boolean, DateTime, ForeignKey, Text
from sqlalchemy.dialects.postgresql import INET, UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.db.base import Base
class Session(Base):
    """Server-side record of an issued auth token, enabling revocation."""
    __tablename__ = "sessions"
    id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    user_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
    # Hash of the token, not the token itself; unique so one row per token.
    token_hash: Mapped[str] = mapped_column(Text, unique=True, nullable=False, index=True)
    ip_address: Mapped[str | None] = mapped_column(INET, nullable=True)
    user_agent: Mapped[str | None] = mapped_column(Text, nullable=True)
    last_active_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    # Non-NULL revoked_at invalidates the session (checked in get_current_user).
    revoked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    user: Mapped["User"] = relationship(back_populates="sessions", lazy="noload")  # type: ignore[name-defined]

View file

@ -0,0 +1,42 @@
import uuid
from datetime import date, datetime
from decimal import Decimal
from sqlalchemy import Boolean, Date, DateTime, ForeignKey, LargeBinary, Numeric, String, Text
from sqlalchemy.dialects.postgresql import ARRAY, JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.db.base import Base
class Transaction(Base):
    """A money movement on an account; text fields are stored encrypted."""
    __tablename__ = "transactions"
    id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    user_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
    account_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("accounts.id"), nullable=False, index=True)
    # Counterparty account for transfers; NULL otherwise.
    transfer_account_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), ForeignKey("accounts.id"), nullable=True)
    category_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), ForeignKey("categories.id"), nullable=True, index=True)
    type: Mapped[str] = mapped_column(String(20), nullable=False)  # income|expense|transfer|investment
    # Defaults to "cleared"; analytics queries exclude status = 'void'.
    status: Mapped[str] = mapped_column(String(20), default="cleared", nullable=False)
    amount: Mapped[Decimal] = mapped_column(Numeric(20, 8), nullable=False)
    # Amount converted into base_currency — presumably amount * exchange_rate; confirm in the conversion service.
    amount_base: Mapped[Decimal | None] = mapped_column(Numeric(20, 8), nullable=True)
    currency: Mapped[str] = mapped_column(String(10), nullable=False)
    base_currency: Mapped[str] = mapped_column(String(10), nullable=False)
    exchange_rate: Mapped[Decimal | None] = mapped_column(Numeric(20, 10), nullable=True)
    date: Mapped[date] = mapped_column(Date, nullable=False, index=True)
    # Ciphertext columns: attribute *_enc maps to the plain column name; encrypted at the application layer per the *_enc convention.
    description_enc: Mapped[bytes] = mapped_column("description", LargeBinary, nullable=False)
    merchant_enc: Mapped[bytes | None] = mapped_column("merchant", LargeBinary, nullable=True)
    notes_enc: Mapped[bytes | None] = mapped_column("notes", LargeBinary, nullable=True)
    tags: Mapped[list[str]] = mapped_column(ARRAY(Text), default=list, nullable=False)
    is_recurring: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
    recurring_rule: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    attachment_refs: Mapped[list] = mapped_column(JSONB, default=list, nullable=False)
    # Indexed hash — presumably for CSV-import de-duplication; confirm against the importer.
    import_hash: Mapped[str | None] = mapped_column(Text, nullable=True, index=True)
    meta: Mapped[dict] = mapped_column(JSONB, default=dict, nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    # Soft delete: queries filter on deleted_at IS NULL.
    deleted_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    # foreign_keys needed because two FKs point at accounts.id.
    account: Mapped["Account"] = relationship(foreign_keys=[account_id], back_populates="transactions", lazy="noload")  # type: ignore[name-defined]
    category: Mapped["Category | None"] = relationship(lazy="noload")  # type: ignore[name-defined]

View file

@ -0,0 +1,33 @@
import uuid
from datetime import datetime
from sqlalchemy import Boolean, DateTime, Integer, String, Text
from sqlalchemy.dialects.postgresql import INET, UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.db.base import Base
class User(Base):
    """An account holder; includes auth, lockout, and preference fields."""
    __tablename__ = "users"
    id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    email: Mapped[str] = mapped_column(Text, unique=True, nullable=False, index=True)
    password_hash: Mapped[str] = mapped_column(Text, nullable=False)
    # NOTE(review): typed Mapped[bytes | None] but the column type is String, while
    # other *_enc fields use LargeBinary — likely should match (requires a migration).
    totp_secret_enc: Mapped[bytes | None] = mapped_column("totp_secret", type_=String, nullable=True)
    totp_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
    # Encrypted backup codes, stored in column "totp_backup_codes".
    totp_backup_codes_enc: Mapped[str | None] = mapped_column("totp_backup_codes", Text, nullable=True)
    display_name: Mapped[str] = mapped_column(Text, nullable=False)
    base_currency: Mapped[str] = mapped_column(String(10), default="GBP", nullable=False)
    theme: Mapped[str] = mapped_column(String(20), default="dark", nullable=False)
    locale: Mapped[str] = mapped_column(String(20), default="en-GB", nullable=False)
    # Brute-force lockout bookkeeping.
    failed_login_attempts: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
    locked_until: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    last_login_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    last_login_ip: Mapped[str | None] = mapped_column(INET, nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    # Soft delete: auth rejects users with non-NULL deleted_at.
    deleted_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    accounts: Mapped[list["Account"]] = relationship(back_populates="user", lazy="noload")  # type: ignore[name-defined]
    sessions: Mapped[list["Session"]] = relationship(back_populates="user", lazy="noload")  # type: ignore[name-defined]

View file

@ -0,0 +1,92 @@
"""
FastAPI dependency injection: DB session, Redis, current authenticated user.
"""
from __future__ import annotations
import uuid
from typing import AsyncGenerator
from fastapi import Cookie, Depends, HTTPException, Request, status
from jose import JWTError
from redis.asyncio import Redis
from sqlalchemy import select, text
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.security import decode_token, hash_token
from app.db.models.session import Session
from app.db.models.user import User
# These are set by main.py lifespan
# Process-global handles; None until the lifespan hook runs, so any request
# served before startup completes would fail here.
_session_factory = None
_redis_client: Redis | None = None


def set_session_factory(factory):
    """Install the async session factory created during app startup."""
    global _session_factory
    _session_factory = factory


def get_session_factory():
    """Return the installed session factory (None before startup)."""
    return _session_factory


def set_redis_client(client: Redis):
    """Install the shared Redis client created during app startup."""
    global _redis_client
    _redis_client = client


async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """FastAPI dependency: yield one AsyncSession per request, closed on exit."""
    async with _session_factory() as session:
        yield session


async def get_redis() -> Redis:
    """FastAPI dependency: return the shared Redis client."""
    return _redis_client
def _extract_bearer(request: Request) -> str | None:
auth = request.headers.get("Authorization", "")
if auth.startswith("Bearer "):
return auth[7:]
return None
async def get_current_user(
    request: Request,
    db: AsyncSession = Depends(get_db),
) -> User:
    """FastAPI dependency resolving the authenticated user.

    Steps: extract the Bearer token, verify the JWT, confirm a live
    server-side session for that exact token, load the (non-deleted) user,
    then set the PostgreSQL RLS context for the rest of the transaction.

    Raises:
        HTTPException 401: missing/invalid token, dead session, or unknown user.
    """
    from datetime import datetime, timezone

    token = _extract_bearer(request)
    if not token:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated")
    try:
        payload = decode_token(token, token_type="access")
        user_id = uuid.UUID(payload["sub"])
    except (JWTError, ValueError, KeyError):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
    # Sessions store only a hash of the token, so hash before comparing.
    token_hash = hash_token(token)
    result = await db.execute(
        select(Session).where(
            Session.user_id == user_id,
            Session.token_hash == token_hash,
            Session.revoked_at.is_(None),
            Session.expires_at > datetime.now(timezone.utc),
        )
    )
    session = result.scalar_one_or_none()
    if not session:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Session expired or revoked")
    result = await db.execute(
        select(User).where(User.id == user_id, User.deleted_at.is_(None))
    )
    user = result.scalar_one_or_none()
    if not user:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
    # Set the RLS context with set_config() and a bound parameter instead of
    # interpolating into the SQL string. user_id is a validated UUID, but
    # parameterizing removes the injection-shaped pattern entirely; the third
    # argument `true` makes the setting transaction-local (SET LOCAL semantics).
    await db.execute(
        text("SELECT set_config('app.current_user_id', :uid, true)"),
        {"uid": str(user_id)},
    )
    return user

88
backend/app/main.py Normal file
View file

@ -0,0 +1,88 @@
"""
FastAPI application factory with lifespan management.
"""
from contextlib import asynccontextmanager
import structlog
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from redis.asyncio import Redis, from_url
from app.config import get_settings
from app.core.middleware import CSRFMiddleware, SecurityHeadersMiddleware
from app.db.base import create_engine, create_session_factory
from app.dependencies import set_redis_client, set_session_factory
logger = structlog.get_logger()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application startup/shutdown hook.

    Startup: wire DB session factory and Redis into app.dependencies, seed
    system categories, then start the background scheduler. Teardown after
    ``yield`` runs in reverse: stop jobs first, then close connections.
    """
    settings = get_settings()
    # Database
    engine = create_engine()
    session_factory = create_session_factory(engine)
    set_session_factory(session_factory)
    # Redis
    redis: Redis = from_url(settings.redis_url, decode_responses=False)
    set_redis_client(redis)
    # Seed system categories if needed (deferred import — presumably to avoid
    # an import cycle at module load; confirm).
    from app.services.category_service import seed_system_categories
    async with session_factory() as db:
        await seed_system_categories(db)
        await db.commit()
    # Background scheduler
    from app.workers.scheduler import start_scheduler, stop_scheduler
    await start_scheduler()
    logger.info("startup_complete", env=settings.environment)
    yield
    # Shutdown path: scheduler before connections so running jobs can finish cleanly.
    await stop_scheduler()
    await redis.aclose()
    await engine.dispose()
    logger.info("shutdown_complete")
def create_app() -> FastAPI:
    """Build and configure the FastAPI application.

    Interactive API docs (/docs, /redoc, /openapi.json) are exposed only in
    development; in production all three are disabled.
    """
    settings = get_settings()
    app = FastAPI(
        title="Finance Tracker",
        version="0.1.0",
        docs_url="/docs" if settings.is_development else None,
        redoc_url="/redoc" if settings.is_development else None,
        openapi_url="/openapi.json" if settings.is_development else None,
        lifespan=lifespan,
    )
    # CORS — only allow same origin in production
    # (an empty allow_origins list in production disables cross-origin access).
    origins = ["http://localhost:5173"] if settings.is_development else []
    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    # NOTE(review): middleware registration order affects the wrapping order of
    # CSRF vs. security headers — confirm the intended ordering.
    app.add_middleware(SecurityHeadersMiddleware)
    app.add_middleware(CSRFMiddleware)
    # Health check (no auth required)
    @app.get("/health")
    async def health():
        return {"status": "ok"}
    # API routers (deferred import — presumably to avoid circular imports; confirm)
    from app.api.router import router
    app.include_router(router, prefix="/api/v1")
    return app


app = create_app()

View file

View file

@ -0,0 +1,119 @@
from __future__ import annotations
import pandas as pd
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
async def get_monthly_category_spending(db: AsyncSession, user_id: str) -> pd.DataFrame:
    """Monthly expense totals per category for one user.

    Returns a DataFrame with columns [category_id, category_name, ds, y]:
    ``ds`` is the month start as a datetime, ``y`` the summed absolute amount.
    An empty DataFrame (same columns) is returned when there are no rows.
    """
    query = text("""
        SELECT
            COALESCE(t.category_id::text, 'uncategorised') AS category_id,
            COALESCE(c.name, 'Uncategorised') AS category_name,
            DATE_TRUNC('month', t.date)::date AS ds,
            SUM(ABS(t.amount))::float AS y
        FROM transactions t
        LEFT JOIN categories c ON c.id = t.category_id
        WHERE t.user_id = CAST(:uid AS uuid)
          AND t.type = 'expense'
          AND t.deleted_at IS NULL
          AND t.status != 'void'
        GROUP BY t.category_id, c.name, DATE_TRUNC('month', t.date)
        ORDER BY ds ASC
    """)
    result = await db.execute(query, {"uid": str(user_id)})
    records = result.fetchall()
    columns = ["category_id", "category_name", "ds", "y"]
    if not records:
        return pd.DataFrame(columns=columns)
    frame = pd.DataFrame(records, columns=columns)
    frame["ds"] = pd.to_datetime(frame["ds"])
    frame["y"] = frame["y"].astype(float)
    return frame
async def get_monthly_net_worth(db: AsyncSession, user_id: str) -> pd.DataFrame:
    """Net-worth history resampled to one observation per calendar month.

    Returns a DataFrame with columns [ds, y] (month-end datetime, net worth),
    or an empty frame with the same columns if no snapshots exist.
    """
    result = await db.execute(text("""
        SELECT date::text AS ds, net_worth::float AS y
        FROM net_worth_snapshots
        WHERE user_id = CAST(:uid AS uuid)
        ORDER BY date ASC
    """), {"uid": str(user_id)})
    records = result.fetchall()
    if not records:
        return pd.DataFrame(columns=["ds", "y"])
    frame = pd.DataFrame(records, columns=["ds", "y"])
    frame["ds"] = pd.to_datetime(frame["ds"])
    frame["y"] = frame["y"].astype(float)
    # Resample to monthly end, keeping the last snapshot within each month.
    monthly = frame.set_index("ds").resample("ME").last().dropna().reset_index()
    monthly.columns = ["ds", "y"]
    return monthly
async def get_current_month_spending(db: AsyncSession, user_id: str) -> pd.DataFrame:
    """Total spend per category within the current calendar month.

    Returns a DataFrame with columns [category_id, category_name, spent];
    empty frame (same columns) when the user has no qualifying expenses.
    """
    query = text("""
        SELECT
            COALESCE(t.category_id::text, 'uncategorised') AS category_id,
            COALESCE(c.name, 'Uncategorised') AS category_name,
            SUM(ABS(t.amount))::float AS spent
        FROM transactions t
        LEFT JOIN categories c ON c.id = t.category_id
        WHERE t.user_id = CAST(:uid AS uuid)
          AND t.type = 'expense'
          AND t.deleted_at IS NULL
          AND t.status != 'void'
          AND DATE_TRUNC('month', t.date) = DATE_TRUNC('month', CURRENT_DATE)
        GROUP BY t.category_id, c.name
    """)
    result = await db.execute(query, {"uid": str(user_id)})
    records = result.fetchall()
    columns = ["category_id", "category_name", "spent"]
    if not records:
        return pd.DataFrame(columns=columns)
    frame = pd.DataFrame(records, columns=columns)
    frame["spent"] = frame["spent"].astype(float)
    return frame
async def get_portfolio_monthly_returns(db: AsyncSession, user_id: str) -> pd.DataFrame:
    """Monthly close prices for each asset in the user's portfolio.

    Picks the latest close within each month per symbol. Returns a DataFrame
    with columns [symbol, month, close], or an empty frame with those columns.
    """
    # FIX: the previous query filtered on h.deleted_at, but the
    # investment_holdings model defines no deleted_at column, so the query
    # failed with an undefined-column error. The predicate is removed.
    result = await db.execute(text("""
        SELECT
            a.symbol,
            DATE_TRUNC('month', ap.date)::date AS month,
            (ARRAY_AGG(ap.close ORDER BY ap.date DESC))[1]::float AS close
        FROM investment_holdings h
        JOIN assets a ON a.id = h.asset_id
        JOIN asset_prices ap ON ap.asset_id = h.asset_id
        WHERE h.user_id = CAST(:uid AS uuid)
        GROUP BY a.symbol, DATE_TRUNC('month', ap.date)
        ORDER BY a.symbol, month ASC
    """), {"uid": str(user_id)})
    rows = result.fetchall()
    if not rows:
        return pd.DataFrame(columns=["symbol", "month", "close"])
    df = pd.DataFrame(rows, columns=["symbol", "month", "close"])
    df["month"] = pd.to_datetime(df["month"])
    df["close"] = df["close"].astype(float)
    return df
async def get_daily_cash_flow(db: AsyncSession, user_id: str, days: int = 90) -> pd.DataFrame:
    """Daily inflow/outflow totals over the trailing ``days`` (default 90).

    Returns a DataFrame with columns [ds, inflow, outflow]; empty frame with
    the same columns when no transactions match.
    """
    # FIX: cast the :days bind parameter explicitly — an untyped parameter in
    # `CURRENT_DATE - :days` is ambiguous between the date-integer and
    # date-interval operators and can fail to resolve at prepare time.
    result = await db.execute(text("""
        SELECT
            t.date::date AS ds,
            SUM(CASE WHEN t.amount > 0 THEN t.amount ELSE 0 END)::float AS inflow,
            SUM(CASE WHEN t.amount < 0 THEN ABS(t.amount) ELSE 0 END)::float AS outflow
        FROM transactions t
        WHERE t.user_id = CAST(:uid AS uuid)
          AND t.deleted_at IS NULL
          AND t.status != 'void'
          AND t.type IN ('income', 'expense')
          AND t.date >= CURRENT_DATE - CAST(:days AS integer)
        GROUP BY t.date
        ORDER BY t.date ASC
    """), {"uid": str(user_id), "days": days})
    rows = result.fetchall()
    if not rows:
        return pd.DataFrame(columns=["ds", "inflow", "outflow"])
    df = pd.DataFrame(rows, columns=["ds", "inflow", "outflow"])
    df["ds"] = pd.to_datetime(df["ds"])
    return df

View file

@ -0,0 +1,135 @@
from __future__ import annotations
from datetime import date
from dateutil.relativedelta import relativedelta
import numpy as np
import pandas as pd
DEFAULT_MU = 0.07 / 12  # 7% annual expected return, expressed per month
DEFAULT_SIGMA = 0.15 / (12 ** 0.5)  # 15% annual vol, expressed per month
# Retained for backward compatibility with any external importers; the
# simulation below steps one month at a time using the monthly parameters
# directly, so no extra time scaling is applied.
DT = 1.0 / 12


def _project_months(from_date: date, n: int) -> list[str]:
    """Labels 'YYYY-MM' for the n months following from_date's month."""
    year, month = from_date.year, from_date.month
    labels = []
    for i in range(1, n + 1):
        total = (month - 1) + i
        labels.append(f"{year + total // 12:04d}-{total % 12 + 1:02d}")
    return labels


def run_monte_carlo(
    prices_df: pd.DataFrame,
    holdings: list[dict],
    years: int = 5,
    n_sims: int = 1000,
    annual_contribution: float = 0.0,
) -> dict:
    """
    prices_df: columns [symbol, month, close]
    holdings: [{"symbol": str, "quantity": float, "current_value": float}]
    Returns percentile paths and summary stats.

    Correlated GBM over monthly steps. Drift/volatility are estimated from
    monthly log-returns (falling back to DEFAULT_MU/DEFAULT_SIGMA), so they
    are applied once per simulated month with no additional dt scaling.

    FIXES vs. previous version:
    - volatility was applied twice (Cholesky of the full covariance AND an
      extra multiplication by sigma); now the Cholesky factor of the
      *correlation* matrix correlates the shocks and sigma scales them once;
    - monthly parameters were additionally scaled by 1/12, collapsing the
      effective drift/vol by an order of magnitude;
    - removed an unused `weights` computation.
    """
    n_months = years * 12
    future_dates = _project_months(date.today(), n_months)
    monthly_contribution = annual_contribution / 12.0
    symbols = [h["symbol"] for h in holdings]
    current_values = np.array([float(h.get("current_value") or 0) for h in holdings])
    total_value = float(current_values.sum())
    if total_value <= 0:
        # Nothing to simulate — return a neutral, clearly-flagged result.
        return {
            "dates": future_dates,
            "percentiles": {},
            "current_value": 0.0,
            "expected_value": 0.0,
            "probability_of_gain": 0.5,
            "insufficient_data": True,
        }
    # Per-asset monthly parameters estimated from price history (defaults otherwise).
    n_assets = len(symbols)
    mus = np.full(n_assets, DEFAULT_MU)
    sigmas = np.full(n_assets, DEFAULT_SIGMA)
    corr = np.eye(n_assets)
    if not prices_df.empty:
        ret_series: dict[str, np.ndarray] = {}
        for i, sym in enumerate(symbols):
            sym_prices = prices_df[prices_df["symbol"] == sym].sort_values("month")
            if len(sym_prices) >= 3:
                closes = sym_prices["close"].values.astype(float)
                log_rets = np.diff(np.log(closes[closes > 0]))
                if len(log_rets) >= 2:
                    mus[i] = float(np.mean(log_rets))
                    sigmas[i] = float(np.std(log_rets))
                    ret_series[sym] = log_rets
        # Correlation only when every asset has an estimable return series.
        if n_assets > 1 and len(ret_series) == n_assets:
            min_len = min(len(v) for v in ret_series.values())
            if min_len >= 3:
                matrix = np.array([v[-min_len:] for v in ret_series.values()])
                corr = np.corrcoef(matrix)
                corr = np.clip(corr, -0.99, 0.99)
                np.fill_diagonal(corr, 1.0)
    # Cholesky factor of the correlation matrix correlates the unit shocks.
    try:
        chol = np.linalg.cholesky(corr)
    except np.linalg.LinAlgError:
        # Non-PSD after clipping — fall back to independent shocks.
        chol = np.eye(n_assets)
    drift = mus - 0.5 * sigmas ** 2  # per-month log-drift (Itô correction)
    rng = np.random.default_rng(42)  # fixed seed: deterministic output per input
    portfolio_paths = np.zeros((n_sims, n_months))
    for sim in range(n_sims):
        asset_values = current_values.copy()
        for t in range(n_months):
            z = rng.standard_normal(n_assets)
            shocks = sigmas * (chol @ z)
            asset_values = asset_values * np.exp(drift + shocks)
            # Contributions accumulate linearly (not compounded into assets),
            # matching the previous reporting behaviour.
            port_val = float(asset_values.sum()) + monthly_contribution * (t + 1)
            portfolio_paths[sim, t] = max(0.0, port_val)
    pcts = {
        "p10": np.percentile(portfolio_paths, 10, axis=0),
        "p25": np.percentile(portfolio_paths, 25, axis=0),
        "p50": np.percentile(portfolio_paths, 50, axis=0),
        "p75": np.percentile(portfolio_paths, 75, axis=0),
        "p90": np.percentile(portfolio_paths, 90, axis=0),
    }
    final_values = portfolio_paths[:, -1]
    prob_gain = float(np.mean(final_values > total_value))
    expected_value = float(np.median(final_values))
    return {
        "dates": future_dates,
        "percentiles": {
            k: [{"date": d, "value": round(float(v), 2)} for d, v in zip(future_dates, arr)]
            for k, arr in pcts.items()
        },
        "current_value": round(total_value, 2),
        "expected_value": round(expected_value, 2),
        "probability_of_gain": round(prob_gain, 3),
        "insufficient_data": False,
    }

View file

@ -0,0 +1,102 @@
from __future__ import annotations
import warnings
from datetime import date
from dateutil.relativedelta import relativedelta
import numpy as np
import pandas as pd
warnings.filterwarnings("ignore")
def _project_months(from_date: date, n: int) -> list[str]:
    """Labels 'YYYY-MM' for the n months following from_date's month."""
    year, month = from_date.year, from_date.month
    months = []
    for i in range(1, n + 1):
        total = (month - 1) + i
        months.append(f"{year + total // 12:04d}-{total % 12 + 1:02d}")
    return months


def project_net_worth(df: pd.DataFrame, years: int = 5) -> dict:
    """
    df columns: ds (monthly datetime), y (net_worth float)
    Returns history + 3-scenario projections.

    FIXES vs. previous version:
    - ExponentialSmoothing.fit() does not accept a `disp` keyword; passing it
      raised TypeError on every call, silently routing all forecasts through
      the linear fallback. The keyword is removed so Holt-Winters is used.
    - Removed a duplicated elif/else branch that built the identical model.
    """
    n_months = years * 12
    today = date.today()
    future_dates = _project_months(today, n_months)
    history = [
        {"date": row["ds"].strftime("%Y-%m"), "value": round(float(row["y"]), 2)}
        for _, row in df.iterrows()
    ]
    if df.empty or len(df) < 2:
        # Not enough data — flat projection from the last known value (or 0).
        last_val = float(df["y"].iloc[-1]) if not df.empty else 0.0
        flat = [{"date": d, "value": round(last_val, 2)} for d in future_dates]
        return {
            "history": history,
            "projections": {"conservative": flat, "base": flat, "optimistic": flat},
            "insufficient_data": True,
        }
    try:
        from statsmodels.tsa.holtwinters import ExponentialSmoothing
        values = df["y"].tolist()
        if len(values) >= 12:
            # Enough history for an annual seasonal component.
            model = ExponentialSmoothing(values, trend="add", seasonal="add", seasonal_periods=12)
        else:
            model = ExponentialSmoothing(values, trend="add", seasonal=None)
        fit = model.fit(optimized=True)
        base_fcast = np.asarray(fit.forecast(n_months))
        # Estimate monthly trend from the fitted forecast's first year.
        monthly_trend = float(np.mean(np.diff(base_fcast[:12]))) if len(base_fcast) >= 12 else 0.0
        last_val = float(values[-1])

        def build_scenario(scale: float) -> list[dict]:
            """Shift the base forecast by (scale - 1) x trend per month."""
            pts = []
            for i, d in enumerate(future_dates):
                v = float(base_fcast[i]) + (scale - 1.0) * monthly_trend * (i + 1)
                pts.append({"date": d, "value": round(v, 2)})
            return pts

        return {
            "history": history,
            "projections": {
                "conservative": build_scenario(0.5),
                "base": [{"date": d, "value": round(float(v), 2)} for d, v in zip(future_dates, base_fcast)],
                "optimistic": build_scenario(1.5),
            },
            "insufficient_data": False,
        }
    except Exception:
        # Fallback (statsmodels missing or model failure): linear trend from
        # the last two observations.
        trend = float(df["y"].iloc[-1]) - float(df["y"].iloc[-2])
        last_val = float(df["y"].iloc[-1])

        def linear_scenario(t_scale: float) -> list[dict]:
            return [
                {"date": d, "value": round(last_val + t_scale * trend * (i + 1), 2)}
                for i, d in enumerate(future_dates)
            ]

        return {
            "history": history,
            "projections": {
                "conservative": linear_scenario(0.5),
                "base": linear_scenario(1.0),
                "optimistic": linear_scenario(1.5),
            },
            "insufficient_data": False,
        }

View file

@ -0,0 +1,91 @@
from __future__ import annotations
import warnings
from datetime import date
from dateutil.relativedelta import relativedelta
import numpy as np
import pandas as pd
warnings.filterwarnings("ignore")
MIN_POINTS = 3
FORECAST_MONTHS = 3
def _next_month_starts(from_date: date, n: int) -> list[str]:
months = []
d = (from_date.replace(day=1) + relativedelta(months=1))
for _ in range(n):
months.append(d.strftime("%Y-%m-%d"))
d += relativedelta(months=1)
return months
def _fit_holt(values: list[float], n: int) -> tuple[list[float], list[float], list[float]]:
from statsmodels.tsa.holtwinters import ExponentialSmoothing
try:
if len(values) >= 12:
model = ExponentialSmoothing(values, trend="add", seasonal="add", seasonal_periods=12)
elif len(values) >= 4:
model = ExponentialSmoothing(values, trend="add", seasonal=None)
else:
model = ExponentialSmoothing(values, trend=None, seasonal=None)
fit = model.fit(optimized=True, disp=False)
forecast = fit.forecast(n)
sigma = float(np.std(fit.resid)) if len(fit.resid) > 1 else float(np.mean(values) * 0.15)
lower = np.maximum(0, forecast - 1.28 * sigma)
upper = forecast + 1.28 * sigma
return forecast.tolist(), lower.tolist(), upper.tolist()
except Exception:
avg = float(np.mean(values))
sigma = float(np.std(values)) if len(values) > 1 else avg * 0.15
return [avg] * n, [max(0, avg - 1.28 * sigma)] * n, [(avg + 1.28 * sigma)] * n
def forecast_spending(df: pd.DataFrame) -> list[dict]:
    """
    df columns: category_id, category_name, ds (monthly), y (amount)
    Returns one forecast dict per category, highest average spend first.
    """
    if df.empty:
        return []
    horizon = _next_month_starts(date.today(), FORECAST_MONTHS)
    out: list[dict] = []
    for (cat_id, cat_name), grp in df.groupby(["category_id", "category_name"]):
        grp = grp.sort_values("ds")
        history = grp["y"].tolist()
        actuals = [
            {"date": row["ds"].strftime("%Y-%m-%d"), "amount": row["y"]}
            for _, row in grp.iterrows()
        ]
        if len(history) >= MIN_POINTS:
            # Enough history for a model-based forecast with a band.
            point, lo, hi = _fit_holt(history, FORECAST_MONTHS)
            forecast_pts = [
                {"date": d, "amount": round(max(0, f), 2), "lower": round(l, 2), "upper": round(u, 2)}
                for d, f, l, u in zip(horizon, point, lo, hi)
            ]
        else:
            # Too little history: flat average with a +/-30% band.
            avg = float(np.mean(history))
            forecast_pts = [
                {"date": d, "amount": round(avg, 2), "lower": round(avg * 0.7, 2), "upper": round(avg * 1.3, 2)}
                for d in horizon
            ]
        out.append({
            "category_id": cat_id,
            "category_name": cat_name,
            "monthly_avg": round(float(np.mean(history)), 2),
            "actuals": actuals[-6:],  # last 6 months for display
            "forecast": forecast_pts,
        })
    # Highest spend first.
    return sorted(out, key=lambda item: item["monthly_avg"], reverse=True)

View file

View file

@ -0,0 +1,59 @@
import uuid
from datetime import datetime
from decimal import Decimal
from typing import Literal
from pydantic import BaseModel, Field
# Closed set of supported account kinds (UK-centric: ISAs, pensions, etc.).
AccountType = Literal[
    "checking", "savings", "cash_isa", "stocks_shares_isa",
    "credit_card", "investment", "cash", "crypto_wallet",
    "loan", "mortgage", "pension", "other"
]


class AccountCreate(BaseModel):
    """Request body for creating an account."""
    name: str = Field(..., min_length=1, max_length=100)
    institution: str | None = None
    type: AccountType
    currency: str = Field(default="GBP", min_length=3, max_length=10)
    credit_limit: Decimal | None = None
    interest_rate: Decimal | None = None
    include_in_net_worth: bool = True
    # Hex colour "#rrggbb"; default indigo.
    color: str = Field(default="#6366f1", pattern=r"^#[0-9a-fA-F]{6}$")
    icon: str | None = None
    notes: str | None = None
    # Starting balance recorded at account creation.
    opening_balance: Decimal = Field(default=Decimal("0"))
class AccountUpdate(BaseModel):
name: str | None = Field(default=None, min_length=1, max_length=100)
institution: str | None = None
opening_balance: Decimal | None = None
credit_limit: Decimal | None = None
interest_rate: Decimal | None = None
include_in_net_worth: bool | None = None
is_active: bool | None = None
color: str | None = Field(default=None, pattern=r"^#[0-9a-fA-F]{6}$")
icon: str | None = None
notes: str | None = None
class AccountResponse(BaseModel):
id: uuid.UUID
name: str
institution: str | None
type: str
currency: str
current_balance: Decimal
credit_limit: Decimal | None
interest_rate: Decimal | None
is_active: bool
include_in_net_worth: bool
color: str
icon: str | None
notes: str | None
created_at: datetime
updated_at: datetime
model_config = {"from_attributes": True}

View file

@ -0,0 +1,64 @@
import uuid
from datetime import datetime
from pydantic import BaseModel, EmailStr, field_validator
class RegisterRequest(BaseModel):
    """Payload for first-time user registration."""

    email: EmailStr
    password: str
    display_name: str

    @field_validator("password")
    @classmethod
    def password_strength(cls, v: str) -> str:
        """Enforce the minimum password policy: >= 12 chars, one uppercase, one digit."""
        if len(v) < 12:
            raise ValueError("Password must be at least 12 characters")
        if not any(c.isupper() for c in v):
            raise ValueError("Password must contain an uppercase letter")
        if not any(c.isdigit() for c in v):
            raise ValueError("Password must contain a digit")
        return v


class LoginRequest(BaseModel):
    """Step 1 of login: email + password."""

    email: EmailStr
    password: str


class TOTPChallengeResponse(BaseModel):
    """Returned instead of tokens when the account has TOTP enabled."""

    totp_required: bool = True
    # Short-lived token authorizing only the TOTP verification step
    challenge_token: str


class TOTPLoginRequest(BaseModel):
    """Step 2 of login: exchange challenge token + TOTP code for real tokens."""

    challenge_token: str
    totp_code: str


class TokenResponse(BaseModel):
    """OAuth2-style bearer token response."""

    access_token: str
    token_type: str = "bearer"
    expires_in: int  # seconds


class TOTPSetupResponse(BaseModel):
    """Material shown once during TOTP enrolment."""

    secret: str
    # PNG of the provisioning QR code, base64-encoded
    qr_code_png_b64: str
    backup_codes: list[str]


class TOTPVerifyRequest(BaseModel):
    """A single 6-digit TOTP code submitted for verification."""

    code: str


class SessionInfo(BaseModel):
    """One active login session, as listed on the security page."""

    id: uuid.UUID
    ip_address: str | None
    user_agent: str | None
    last_active_at: datetime
    expires_at: datetime
    created_at: datetime
    # True for the session making the request
    is_current: bool = False

    model_config = {"from_attributes": True}

View file

@ -0,0 +1,65 @@
import uuid
from datetime import date as DateType, datetime
from decimal import Decimal
from typing import Literal
from pydantic import BaseModel, Field
# Supported budgeting cadences; the active window is derived in the budget service.
BudgetPeriod = Literal["weekly", "monthly", "quarterly", "yearly"]


class BudgetCreate(BaseModel):
    """Payload for creating a budget against a single category."""

    category_id: uuid.UUID
    name: str = Field(..., min_length=1, max_length=200)
    amount: Decimal = Field(..., gt=0)
    currency: str = Field(default="GBP", min_length=3, max_length=10)
    period: BudgetPeriod = "monthly"
    start_date: DateType
    end_date: DateType | None = None  # open-ended when None
    # presumably carries unspent amount into the next period — TODO confirm usage
    rollover: bool = False
    # Percent of the budget (0-100) at which an alert is triggered
    alert_threshold: Decimal = Field(default=Decimal("80"), ge=0, le=100)


class BudgetUpdate(BaseModel):
    """Partial update; category_id/currency/start_date are immutable here."""

    name: str | None = Field(default=None, min_length=1, max_length=200)
    amount: Decimal | None = Field(default=None, gt=0)
    period: BudgetPeriod | None = None
    end_date: DateType | None = None
    rollover: bool | None = None
    alert_threshold: Decimal | None = Field(default=None, ge=0, le=100)
    is_active: bool | None = None


class BudgetResponse(BaseModel):
    """A budget row as returned by the API."""

    id: uuid.UUID
    category_id: uuid.UUID
    name: str
    amount: Decimal
    currency: str
    period: str
    start_date: DateType
    end_date: DateType | None
    rollover: bool
    alert_threshold: Decimal
    is_active: bool
    created_at: datetime
    updated_at: datetime

    model_config = {"from_attributes": True}


class BudgetSummaryItem(BaseModel):
    """Computed spend-vs-limit summary for one budget in its current period."""

    budget_id: uuid.UUID
    budget_name: str
    category_id: uuid.UUID
    category_name: str
    period: str
    budget_amount: Decimal
    spent_amount: Decimal
    remaining_amount: Decimal
    percent_used: Decimal
    is_over_budget: bool
    alert_triggered: bool
    currency: str
    period_start: DateType
    period_end: DateType

View file

@ -0,0 +1,103 @@
import uuid
from datetime import date as DateType, datetime
from decimal import Decimal
from typing import Literal
from pydantic import BaseModel, Field
# Investment transaction kinds recorded against a holding.
InvestmentTxnType = Literal["buy", "sell", "dividend", "split", "fee", "transfer_in", "transfer_out"]


class AssetSearch(BaseModel):
    """A tradable asset as returned by the asset search endpoint."""

    id: uuid.UUID
    symbol: str
    name: str
    type: str
    currency: str
    exchange: str | None
    last_price: Decimal | None
    price_change_24h: Decimal | None
    # Which upstream provider supplied the quote
    data_source: str

    model_config = {"from_attributes": True}


class AssetPricePoint(BaseModel):
    """One OHLCV bar; only close is guaranteed present."""

    date: DateType
    open: Decimal | None
    high: Decimal | None
    low: Decimal | None
    close: Decimal
    volume: Decimal | None

    model_config = {"from_attributes": True}


class HoldingCreate(BaseModel):
    """Payload for manually registering a position in an account."""

    account_id: uuid.UUID
    asset_id: uuid.UUID
    quantity: Decimal = Field(..., gt=0)
    # Average price paid per unit; zero allowed (e.g. gifted shares)
    avg_cost_basis: Decimal = Field(..., ge=0)
    currency: str = Field(default="GBP", min_length=3, max_length=10)


class HoldingResponse(BaseModel):
    """A holding enriched with live valuation figures (None when no price available)."""

    id: uuid.UUID
    account_id: uuid.UUID
    asset_id: uuid.UUID
    symbol: str
    asset_name: str
    asset_type: str
    quantity: Decimal
    avg_cost_basis: Decimal
    current_price: Decimal | None
    current_value: Decimal | None
    cost_basis_total: Decimal
    unrealised_gain: Decimal | None
    unrealised_gain_pct: Decimal | None
    currency: str
    price_change_24h: Decimal | None

    model_config = {"from_attributes": True}


class InvestmentTxnCreate(BaseModel):
    """Payload for recording a trade/dividend/fee against a holding."""

    holding_id: uuid.UUID
    type: InvestmentTxnType
    # ge (not gt): zero quantity is valid for e.g. fee/dividend entries
    quantity: Decimal = Field(..., ge=0)
    price: Decimal = Field(..., ge=0)
    fees: Decimal = Field(default=Decimal("0"), ge=0)
    currency: str = Field(default="GBP", min_length=3, max_length=10)
    date: DateType
    notes: str | None = None


class InvestmentTxnResponse(BaseModel):
    """An investment transaction as returned by the API."""

    id: uuid.UUID
    holding_id: uuid.UUID
    type: str
    quantity: Decimal
    price: Decimal
    fees: Decimal
    total_amount: Decimal
    currency: str
    date: DateType
    created_at: datetime

    model_config = {"from_attributes": True}


class PortfolioSummary(BaseModel):
    """Aggregate portfolio valuation plus the per-holding breakdown."""

    total_value: Decimal
    total_cost: Decimal
    total_gain: Decimal
    total_gain_pct: Decimal
    currency: str
    holdings: list[HoldingResponse]


class PerformanceMetrics(BaseModel):
    """Portfolio performance figures; twrr is None when not computable."""

    # Time-weighted rate of return
    twrr: Decimal | None
    total_return: Decimal
    total_return_pct: Decimal
    currency: str

View file

@ -0,0 +1,96 @@
from datetime import date as DateType
from decimal import Decimal
from pydantic import BaseModel
class NetWorthPoint(BaseModel):
    """One point on the net-worth time series."""

    date: DateType
    total_assets: Decimal
    total_liabilities: Decimal
    net_worth: Decimal
    base_currency: str


class NetWorthReport(BaseModel):
    """Net-worth history plus 30-day change headline figures."""

    points: list[NetWorthPoint]
    current_net_worth: Decimal
    change_30d: Decimal
    change_30d_pct: Decimal
    base_currency: str


class IncomeExpensePoint(BaseModel):
    """Monthly income/expense totals."""

    month: str  # "2024-01"
    income: Decimal
    expenses: Decimal
    net: Decimal


class IncomeExpenseReport(BaseModel):
    """Income vs expense series with overall and monthly-average totals."""

    points: list[IncomeExpensePoint]
    total_income: Decimal
    total_expenses: Decimal
    avg_monthly_income: Decimal
    avg_monthly_expenses: Decimal
    currency: str


class CashFlowPoint(BaseModel):
    """Daily cash movement with a running balance."""

    date: DateType
    inflow: Decimal
    outflow: Decimal
    net: Decimal
    running_balance: Decimal


class CashFlowReport(BaseModel):
    """Cash-flow series plus overall inflow/outflow totals."""

    points: list[CashFlowPoint]
    total_inflow: Decimal
    total_outflow: Decimal
    currency: str


class CategoryBreakdownItem(BaseModel):
    """Spend in one category; category_id is None for uncategorized transactions."""

    category_id: str | None
    category_name: str
    amount: Decimal
    percent: Decimal
    transaction_count: int


class CategoryBreakdownReport(BaseModel):
    """Per-category spend breakdown over a date range."""

    items: list[CategoryBreakdownItem]
    total: Decimal
    currency: str
    date_from: DateType
    date_to: DateType


class BudgetVsActualItem(BaseModel):
    """Budgeted vs actual spend for one budget."""

    budget_id: str
    budget_name: str
    category_name: str
    budgeted: Decimal
    actual: Decimal
    variance: Decimal
    percent_used: Decimal


class BudgetVsActualReport(BaseModel):
    """All budgets compared against actual spend."""

    items: list[BudgetVsActualItem]
    total_budgeted: Decimal
    total_actual: Decimal
    currency: str


class SpendingTrendPoint(BaseModel):
    """Spend for one category in one month (long/tidy format)."""

    month: str
    category_name: str
    amount: Decimal


class SpendingTrendsReport(BaseModel):
    """Category spending trends; `categories` lists the distinct series names."""

    points: list[SpendingTrendPoint]
    categories: list[str]
    currency: str

View file

@ -0,0 +1,77 @@
import uuid
from datetime import date as DateType, datetime
from decimal import Decimal
from typing import Literal
from pydantic import BaseModel, Field
# Transaction kind and lifecycle status, validated at the API boundary.
TransactionType = Literal["income", "expense", "transfer", "investment"]
TransactionStatus = Literal["pending", "cleared", "reconciled", "void"]


class TransactionCreate(BaseModel):
    """Payload for creating a transaction.

    NOTE(review): no sign convention is enforced on `amount` here — confirm
    whether the service layer normalizes signs for expense vs income.
    """

    account_id: uuid.UUID
    # Destination account, only meaningful for type="transfer"
    transfer_account_id: uuid.UUID | None = None
    category_id: uuid.UUID | None = None
    type: TransactionType
    status: TransactionStatus = "cleared"
    amount: Decimal
    currency: str = Field(default="GBP", min_length=3, max_length=10)
    date: DateType
    description: str = Field(..., min_length=1, max_length=500)
    merchant: str | None = None
    notes: str | None = None
    tags: list[str] = []
    is_recurring: bool = False
    # Free-form recurrence definition; schema not validated here
    recurring_rule: dict | None = None


class TransactionUpdate(BaseModel):
    """Partial update; account, type, and currency are immutable here."""

    category_id: uuid.UUID | None = None
    status: TransactionStatus | None = None
    amount: Decimal | None = None
    date: DateType | None = None
    description: str | None = Field(default=None, min_length=1, max_length=500)
    merchant: str | None = None
    notes: str | None = None
    tags: list[str] | None = None


class TransactionFilter(BaseModel):
    """Query-side filter + pagination parameters for listing transactions."""

    account_id: uuid.UUID | None = None
    category_id: uuid.UUID | None = None
    type: TransactionType | None = None
    status: TransactionStatus | None = None
    date_from: DateType | None = None
    date_to: DateType | None = None
    min_amount: Decimal | None = None
    max_amount: Decimal | None = None
    # Free-text search term; matching semantics live in the service layer
    search: str | None = None
    tags: list[str] = []
    page: int = Field(default=1, ge=1)
    page_size: int = Field(default=50, ge=1, le=200)


class TransactionResponse(BaseModel):
    """A transaction as returned by the API, including base-currency conversion
    fields (None when no conversion was recorded)."""

    id: uuid.UUID
    account_id: uuid.UUID
    transfer_account_id: uuid.UUID | None
    category_id: uuid.UUID | None
    type: str
    status: str
    amount: Decimal
    amount_base: Decimal | None
    currency: str
    base_currency: str
    exchange_rate: Decimal | None
    date: DateType
    description: str
    merchant: str | None
    notes: str | None
    tags: list[str]
    is_recurring: bool
    # Lightweight references to attached receipts/documents
    attachment_refs: list[dict] = []
    created_at: datetime
    updated_at: datetime

    model_config = {"from_attributes": True}

View file

View file

@ -0,0 +1,195 @@
from __future__ import annotations
import uuid
from datetime import datetime, timezone
from decimal import Decimal
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.security import encrypt_field, decrypt_field
from app.db.models.account import Account
from app.db.models.transaction import Transaction
from app.schemas.account import AccountCreate, AccountUpdate
# Account types that are liabilities (balance is negative contribution to net worth)
LIABILITY_TYPES = {"credit_card", "loan", "mortgage"}
class AccountError(Exception):
    """Domain error for account operations, carrying an HTTP-style status code."""

    def __init__(self, detail: str, status_code: int = 400):
        self.detail, self.status_code = detail, status_code
def _encrypt(value: str | None) -> bytes | None:
    """Encrypt a field value, passing None through unchanged."""
    return None if value is None else encrypt_field(value)
def _decrypt(data: bytes | None) -> str | None:
    """Decrypt a field value; falsy ciphertext (None or empty) maps to None."""
    return decrypt_field(data) if data else None
def _to_response(account: Account) -> dict:
    """Serialize an Account ORM row into a plain dict, decrypting sensitive fields."""
    return dict(
        id=account.id,
        name=_decrypt(account.name_enc) or "",
        institution=_decrypt(account.institution_enc),
        type=account.type,
        currency=account.currency,
        current_balance=account.current_balance,
        credit_limit=account.credit_limit,
        interest_rate=account.interest_rate,
        is_active=account.is_active,
        include_in_net_worth=account.include_in_net_worth,
        color=account.color,
        icon=account.icon,
        notes=_decrypt(account.notes_enc),
        created_at=account.created_at,
        updated_at=account.updated_at,
    )
async def create_account(
    db: AsyncSession,
    user_id: uuid.UUID,
    data: AccountCreate,
) -> dict:
    """Create an account for the user and return its serialized dict.

    Sensitive free-text fields (name, institution, notes) are stored
    encrypted; the opening balance seeds current_balance.
    """
    now = datetime.now(timezone.utc)
    account = Account(
        user_id=user_id,
        name_enc=encrypt_field(data.name),
        institution_enc=_encrypt(data.institution),
        type=data.type,
        currency=data.currency,
        current_balance=data.opening_balance,  # opening balance seeds the running balance
        credit_limit=data.credit_limit,
        interest_rate=data.interest_rate,
        include_in_net_worth=data.include_in_net_worth,
        color=data.color,
        icon=data.icon,
        notes_enc=_encrypt(data.notes),
        created_at=now,
        updated_at=now,
    )
    db.add(account)
    # flush (not commit) so DB-populated defaults such as the id are available;
    # commit is presumably owned by the request-scoped session — TODO confirm
    await db.flush()
    return _to_response(account)
async def list_accounts(db: AsyncSession, user_id: uuid.UUID) -> list[dict]:
    """Return all of the user's non-deleted accounts, oldest first."""
    stmt = (
        select(Account)
        .where(Account.user_id == user_id, Account.deleted_at.is_(None))
        .order_by(Account.created_at)
    )
    rows = await db.execute(stmt)
    return [_to_response(acct) for acct in rows.scalars()]
async def get_account(db: AsyncSession, account_id: uuid.UUID, user_id: uuid.UUID) -> Account:
    """Fetch one of the user's non-deleted accounts or raise AccountError(404)."""
    stmt = select(Account).where(
        Account.id == account_id,
        Account.user_id == user_id,
        Account.deleted_at.is_(None),
    )
    found = (await db.execute(stmt)).scalar_one_or_none()
    if found is None:
        raise AccountError("Account not found", status_code=404)
    return found
async def update_account(
    db: AsyncSession,
    account_id: uuid.UUID,
    user_id: uuid.UUID,
    data: AccountUpdate,
) -> dict:
    """Apply a partial update to one of the user's accounts and return its dict.

    Only fields supplied as non-None are written; encrypted fields are
    re-encrypted on change.

    NOTE(review): because `is not None` is the guard, nullable fields
    (institution, notes, icon, credit_limit, ...) can never be cleared back
    to NULL through this path — confirm whether that is intended (pydantic's
    model_dump(exclude_unset=True) would distinguish "unset" from "null").
    NOTE(review): `opening_balance` overwrites current_balance directly and
    would be clobbered by a later recalculate_balance() — verify.
    """
    account = await get_account(db, account_id, user_id)
    now = datetime.now(timezone.utc)
    if data.name is not None:
        account.name_enc = encrypt_field(data.name)
    if data.institution is not None:
        account.institution_enc = _encrypt(data.institution)
    if data.opening_balance is not None:
        account.current_balance = data.opening_balance
    if data.credit_limit is not None:
        account.credit_limit = data.credit_limit
    if data.interest_rate is not None:
        account.interest_rate = data.interest_rate
    if data.include_in_net_worth is not None:
        account.include_in_net_worth = data.include_in_net_worth
    if data.is_active is not None:
        account.is_active = data.is_active
    if data.color is not None:
        account.color = data.color
    if data.icon is not None:
        account.icon = data.icon
    if data.notes is not None:
        account.notes_enc = _encrypt(data.notes)
    account.updated_at = now
    await db.flush()
    return _to_response(account)
async def delete_account(
    db: AsyncSession,
    account_id: uuid.UUID,
    user_id: uuid.UUID,
) -> None:
    """Soft-delete an account by stamping deleted_at (rows are never hard-deleted).

    Raises AccountError(404) when the account is missing or not owned by the user.
    """
    account = await get_account(db, account_id, user_id)
    account.deleted_at = datetime.now(timezone.utc)
    account.updated_at = datetime.now(timezone.utc)
    await db.flush()
async def recalculate_balance(db: AsyncSession, account_id: uuid.UUID) -> None:
    """Recompute current_balance from all non-deleted transactions."""
    result = await db.execute(
        select(func.sum(Transaction.amount)).where(
            Transaction.account_id == account_id,
            Transaction.deleted_at.is_(None),
        )
    )
    # SUM over zero rows yields SQL NULL -> None; treat as a zero balance.
    total = result.scalar_one_or_none() or Decimal("0")
    account = await db.get(Account, account_id)
    if account:  # silently no-op when the account row no longer exists
        account.current_balance = total
        account.updated_at = datetime.now(timezone.utc)
        await db.flush()
async def get_net_worth(db: AsyncSession, user_id: uuid.UUID, base_currency: str) -> dict:
    """Aggregate current balances into assets / liabilities / net worth.

    Liability-type accounts (see LIABILITY_TYPES) contribute the absolute
    value of their balance to total_liabilities; everything else counts as
    an asset. Balances are summed as-is — FX conversion to base_currency is
    still a TODO.
    """
    accounts = await db.execute(
        select(Account).where(
            Account.user_id == user_id,
            # .is_(True) renders the same SQL predicate as `== True` without
            # the E712 anti-pattern.
            Account.include_in_net_worth.is_(True),
            Account.deleted_at.is_(None),
        )
    )
    total_assets = Decimal("0")
    total_liabilities = Decimal("0")
    for account in accounts.scalars():
        # TODO Phase 3: convert to base_currency via FX rates
        bal = account.current_balance
        if account.type in LIABILITY_TYPES:
            total_liabilities += abs(bal)
        else:
            total_assets += bal
    return {
        "total_assets": total_assets,
        "total_liabilities": total_liabilities,
        "net_worth": total_assets - total_liabilities,
        "base_currency": base_currency,
    }

View file

@ -0,0 +1,258 @@
"""
Authentication service: register, login, TOTP, sessions, brute-force protection.
"""
from __future__ import annotations
import base64
import uuid
from datetime import datetime, timedelta, timezone
from jose import JWTError
from redis.asyncio import Redis
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import get_settings
from app.core.security import (
create_access_token,
create_refresh_token,
decrypt_field,
decode_token,
encrypt_field,
generate_backup_codes,
generate_csrf_token,
generate_totp_qr_png,
generate_totp_secret,
hash_password,
hash_token,
verify_password,
verify_totp,
)
from app.db.models.session import Session
from app.db.models.user import User
class AuthError(Exception):
    """Authentication/authorization failure carrying an HTTP status code (default 401)."""

    def __init__(self, detail: str, status_code: int = 401):
        self.detail, self.status_code = detail, status_code
async def _lockout_key(email: str) -> str:
return f"lockout:{email}"
async def _check_and_record_failure(redis: Redis, email: str, settings) -> None:
    """Record one failed login attempt and (re)arm the lockout TTL.

    The first failure sets a base TTL on the counter; once attempts reach
    the configured max, the TTL grows exponentially with each further
    failure, capped at 24 hours.
    """
    key = await _lockout_key(email)
    attempts = await redis.incr(key)
    if attempts == 1:
        await redis.expire(key, settings.lockout_base_seconds)
    if attempts >= settings.max_login_attempts:
        lockout_seconds = settings.lockout_base_seconds * (2 ** (attempts - settings.max_login_attempts))
        await redis.expire(key, min(lockout_seconds, 86400))  # cap at 24h
async def _is_locked_out(redis: Redis, email: str) -> bool:
    """True when the email's failed-attempt counter has reached the configured max."""
    key = await _lockout_key(email)
    val = await redis.get(key)
    if val is None:
        return False
    settings = get_settings()
    return int(val) >= settings.max_login_attempts
async def register_user(db: AsyncSession, email: str, password: str, display_name: str) -> User:
    """Create a new user account.

    When registration is disabled in settings (single-user deployments),
    only the very first user may register.

    Raises:
        AuthError: 403 when registration is disabled and a user exists,
            409 when the email is already taken.
    """
    settings = get_settings()
    # Single-user: block registration if user already exists
    if not settings.allow_registration:
        count = await db.scalar(select(func.count()).select_from(User).where(User.deleted_at.is_(None)))
        if count and count > 0:
            raise AuthError("Registration is disabled", status_code=403)
    # NOTE(review): this duplicate check does not filter deleted_at, so a
    # soft-deleted user's email permanently blocks re-registration — confirm.
    existing = await db.scalar(select(User).where(User.email == email))
    if existing:
        raise AuthError("Email already registered", status_code=409)
    now = datetime.now(timezone.utc)
    user = User(
        email=email,
        password_hash=hash_password(password),
        display_name=display_name,
        base_currency=settings.base_currency,
        created_at=now,
        updated_at=now,
    )
    db.add(user)
    await db.flush()
    return user
async def authenticate_user(
    db: AsyncSession,
    redis: Redis,
    email: str,
    password: str,
    ip: str | None,
    user_agent: str | None,
) -> tuple[User, str, str] | tuple[User, None, None]:
    """
    Returns (user, access_token, refresh_token) if no TOTP required,
    or (user, None, None) if TOTP challenge needed.
    Raises AuthError on failure.
    """
    settings = get_settings()
    # Reject cheaply while a lockout window is active, before touching the DB.
    if await _is_locked_out(redis, email):
        raise AuthError("Account temporarily locked due to too many failed attempts", status_code=429)
    user = await db.scalar(
        select(User).where(User.email == email, User.deleted_at.is_(None))
    )
    # Identical error message for unknown email and wrong password.
    if not user or not verify_password(password, user.password_hash):
        await _check_and_record_failure(redis, email, settings)
        raise AuthError("Invalid email or password")
    # Clear lockout on success
    await redis.delete(await _lockout_key(email))
    if user.totp_enabled:
        return user, None, None  # Caller creates challenge token
    tokens = await _create_session(db, user, ip, user_agent)
    return user, *tokens
async def _create_session(
    db: AsyncSession,
    user: User,
    ip: str | None,
    user_agent: str | None,
) -> tuple[str, str]:
    """Mint access + refresh tokens and persist a Session row for this login.

    Also stamps the user's last-login metadata.

    NOTE(review): token_hash stores a hash of the *access* token while
    expires_at uses the *refresh* token lifetime — confirm this pairing is
    intentional; the refresh token itself is not persisted here.
    """
    settings = get_settings()
    access_token = create_access_token(str(user.id))
    refresh_token = create_refresh_token(str(user.id))
    now = datetime.now(timezone.utc)
    session = Session(
        user_id=user.id,
        token_hash=hash_token(access_token),
        ip_address=ip,
        user_agent=user_agent,
        last_active_at=now,
        expires_at=now + timedelta(days=settings.refresh_token_expire_days),
        created_at=now,
    )
    db.add(session)
    await db.flush()
    # Update user login info
    user.last_login_at = now
    user.last_login_ip = ip
    user.updated_at = now
    return access_token, refresh_token
async def complete_totp_login(
    db: AsyncSession,
    challenge_token: str,
    totp_code: str,
    ip: str | None,
    user_agent: str | None,
) -> tuple[str, str]:
    """Exchange a short-lived TOTP challenge token plus a valid code for a session.

    Returns (access_token, refresh_token). Raises AuthError for an
    invalid/expired challenge token, missing TOTP setup, or a wrong code.
    """
    try:
        payload = decode_token(challenge_token, token_type="totp_challenge")
        user_id = uuid.UUID(payload["sub"])
    except (JWTError, ValueError, KeyError):
        raise AuthError("Invalid or expired challenge token")
    user = await db.get(User, user_id)
    if not user or not user.totp_enabled or not user.totp_secret_enc:
        raise AuthError("Invalid challenge")
    # Compat shim: the encrypted secret may be stored hex-encoded (str) or as raw bytes.
    secret = decrypt_field(bytes.fromhex(user.totp_secret_enc) if isinstance(user.totp_secret_enc, str) else user.totp_secret_enc)
    if not verify_totp(secret, totp_code):
        raise AuthError("Invalid TOTP code")
    return await _create_session(db, user, ip, user_agent)
def create_totp_challenge_token(user_id: uuid.UUID) -> str:
    """Issue a 5-minute JWT that authorizes only the TOTP verification step.

    Signed with the configured private key like regular tokens, but typed
    "totp_challenge" so decode_token() will reject it as an access token.
    """
    # Only these two names are not already imported at module level; the
    # original also re-imported datetime/timedelta/timezone/get_settings and
    # an unused create_access_token — all removed.
    from pathlib import Path

    from jose import jwt

    settings = get_settings()
    now = datetime.now(timezone.utc)
    payload = {
        "sub": str(user_id),
        "iat": now,
        "exp": now + timedelta(minutes=5),
        "type": "totp_challenge",
    }
    private_key = Path(settings.jwt_private_key_file).read_text()
    return jwt.encode(payload, private_key, algorithm=settings.jwt_algorithm)
async def setup_totp(user: User, db: AsyncSession) -> tuple[str, str, list[str]]:
    """Generate TOTP secret, QR code, and backup codes. Does not enable TOTP yet.

    Returns (secret, qr_png_base64, backup_codes).

    NOTE(review): nothing is persisted here (`db` is unused) — the secret only
    takes effect via enable_totp. Verify the backup codes are stored somewhere,
    since enable_totp does not save them either.
    """
    secret = generate_totp_secret()
    qr_png = generate_totp_qr_png(secret, user.email)
    backup_codes = generate_backup_codes(8)
    return secret, base64.b64encode(qr_png).decode(), backup_codes
async def enable_totp(user: User, db: AsyncSession, secret: str, code: str) -> None:
    """Verify the first code against *secret*, then store it encrypted and flag TOTP on.

    Raises AuthError(400) when the code does not match the secret.
    """
    if not verify_totp(secret, code):
        raise AuthError("Invalid TOTP code — setup failed", status_code=400)
    encrypted = encrypt_field(secret)
    # Stored hex-encoded; complete_totp_login accepts both hex-str and raw bytes.
    user.totp_secret_enc = encrypted.hex()
    user.totp_enabled = True
    user.updated_at = datetime.now(timezone.utc)
    await db.flush()
async def disable_totp(user: User, db: AsyncSession, password: str) -> None:
    """Turn off TOTP after re-authenticating with the account password.

    Clears the stored secret and backup codes. Raises AuthError(400) on a
    wrong password.
    """
    if not verify_password(password, user.password_hash):
        raise AuthError("Incorrect password", status_code=400)
    user.totp_secret_enc = None
    user.totp_enabled = False
    user.totp_backup_codes_enc = None
    user.updated_at = datetime.now(timezone.utc)
    await db.flush()
async def revoke_session(db: AsyncSession, session_id: uuid.UUID, user_id: uuid.UUID) -> None:
    """Revoke one session owned by the user.

    Ownership failure is reported as the same 404 as not-found, so the
    endpoint does not leak whether another user's session id exists.
    """
    session = await db.get(Session, session_id)
    if not session or session.user_id != user_id:
        raise AuthError("Session not found", status_code=404)
    session.revoked_at = datetime.now(timezone.utc)
    await db.flush()
async def revoke_all_sessions(db: AsyncSession, user_id: uuid.UUID, except_token_hash: str | None = None) -> None:
    """Mark every unrevoked session for the user as revoked, optionally sparing
    the caller's own session (matched by token hash)."""
    from sqlalchemy import update

    conditions = [Session.user_id == user_id, Session.revoked_at.is_(None)]
    if except_token_hash:
        conditions.append(Session.token_hash != except_token_hash)
    await db.execute(
        update(Session)
        .where(*conditions)
        .values(revoked_at=datetime.now(timezone.utc))
    )
async def get_sessions(db: AsyncSession, user_id: uuid.UUID) -> list[Session]:
    """List the user's live sessions (unrevoked and unexpired), newest first."""
    stmt = (
        select(Session)
        .where(
            Session.user_id == user_id,
            Session.revoked_at.is_(None),
            Session.expires_at > datetime.now(timezone.utc),
        )
        .order_by(Session.created_at.desc())
    )
    return list((await db.execute(stmt)).scalars())

View file

@ -0,0 +1,137 @@
import uuid
from datetime import date, datetime, timezone
from decimal import Decimal
from dateutil.relativedelta import relativedelta
from sqlalchemy import and_, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models.budget import Budget
from app.db.models.category import Category
from app.db.models.transaction import Transaction
from app.schemas.budget import BudgetCreate, BudgetSummaryItem, BudgetUpdate
def _period_bounds(period: str, ref: date) -> tuple[date, date]:
if period == "weekly":
start = ref - relativedelta(days=ref.weekday())
end = start + relativedelta(days=6)
elif period == "monthly":
start = ref.replace(day=1)
end = (start + relativedelta(months=1)) - relativedelta(days=1)
elif period == "quarterly":
q = (ref.month - 1) // 3
start = date(ref.year, q * 3 + 1, 1)
end = (start + relativedelta(months=3)) - relativedelta(days=1)
else: # yearly
start = date(ref.year, 1, 1)
end = date(ref.year, 12, 31)
return start, end
async def create_budget(db: AsyncSession, user_id: uuid.UUID, data: BudgetCreate) -> Budget:
    """Persist a new active budget for the user and return the refreshed row."""
    now = datetime.now(timezone.utc)
    budget = Budget(
        id=uuid.uuid4(),  # client-side UUID; no DB round-trip needed for the PK
        user_id=user_id,
        category_id=data.category_id,
        name=data.name,
        amount=data.amount,
        currency=data.currency,
        period=data.period,
        start_date=data.start_date,
        end_date=data.end_date,
        rollover=data.rollover,
        alert_threshold=data.alert_threshold,
        is_active=True,
        created_at=now,
        updated_at=now,
    )
    db.add(budget)
    await db.flush()
    await db.refresh(budget)  # pull any server-side defaults back into the object
    return budget
async def list_budgets(db: AsyncSession, user_id: uuid.UUID, active_only: bool = True) -> list[Budget]:
    """Return the user's budgets ordered by name.

    Args:
        active_only: when True (default), restrict to budgets with is_active set.
    """
    q = select(Budget).where(Budget.user_id == user_id)
    if active_only:
        # .is_(True) emits the same SQL as `== True` without needing a noqa for E712.
        q = q.where(Budget.is_active.is_(True))
    q = q.order_by(Budget.name)
    result = await db.execute(q)
    return list(result.scalars().all())
async def get_budget(db: AsyncSession, user_id: uuid.UUID, budget_id: uuid.UUID) -> Budget | None:
    """Look up one budget scoped to the owning user; None when not found."""
    stmt = select(Budget).where(Budget.id == budget_id, Budget.user_id == user_id)
    return (await db.execute(stmt)).scalar_one_or_none()
async def update_budget(db: AsyncSession, budget: Budget, data: BudgetUpdate) -> Budget:
    """Apply the client-supplied fields to *budget* and return the refreshed row.

    exclude_unset means only fields the client actually sent are applied, so
    an explicit null (e.g. clearing end_date) is honoured.
    """
    for field, value in data.model_dump(exclude_unset=True).items():
        setattr(budget, field, value)
    budget.updated_at = datetime.now(timezone.utc)
    await db.flush()
    await db.refresh(budget)
    return budget
async def delete_budget(db: AsyncSession, budget: Budget) -> None:
    """Hard-delete a budget row — unlike accounts, budgets have no soft-delete here."""
    await db.delete(budget)
    await db.flush()
async def get_budget_summary(db: AsyncSession, user_id: uuid.UUID) -> list[BudgetSummaryItem]:
    """Build per-budget spend-vs-limit summaries for the current period.

    For each active budget: resolve its category name, sum |amount| of the
    category's non-void, non-deleted expense transactions inside the current
    period window, and derive remaining/percent/alert flags.

    NOTE(review): two queries are issued per budget (category + spend) — an
    N+1 pattern; fine for a handful of budgets, revisit if counts grow.
    NOTE(review): the window comes from today's date only — budget.start_date,
    end_date, and the rollover flag are not consulted here; confirm intended.
    """
    budgets = await list_budgets(db, user_id, active_only=True)
    today = date.today()
    items: list[BudgetSummaryItem] = []
    for budget in budgets:
        period_start, period_end = _period_bounds(budget.period, today)
        # Fetch category name
        cat_result = await db.execute(select(Category).where(Category.id == budget.category_id))
        category = cat_result.scalar_one_or_none()
        cat_name = category.name if category else "Unknown"
        # Sum actual spending in this period
        spent_result = await db.execute(
            select(func.coalesce(func.sum(func.abs(Transaction.amount)), Decimal("0")))
            .where(
                and_(
                    Transaction.user_id == user_id,
                    Transaction.category_id == budget.category_id,
                    Transaction.type == "expense",
                    Transaction.status != "void",
                    Transaction.date >= period_start,
                    Transaction.date <= period_end,
                    Transaction.deleted_at.is_(None),
                )
            )
        )
        # Normalize driver-specific numeric types to Decimal via str round-trip
        spent = Decimal(str(spent_result.scalar() or 0))
        remaining = budget.amount - spent
        pct = (spent / budget.amount * 100) if budget.amount > 0 else Decimal("0")
        items.append(
            BudgetSummaryItem(
                budget_id=budget.id,
                budget_name=budget.name,
                category_id=budget.category_id,
                category_name=cat_name,
                period=budget.period,
                budget_amount=budget.amount,
                spent_amount=spent,
                remaining_amount=remaining,
                percent_used=pct.quantize(Decimal("0.01")),
                is_over_budget=spent > budget.amount,
                alert_triggered=pct >= budget.alert_threshold,
                currency=budget.currency,
                period_start=period_start,
                period_end=period_end,
            )
        )
    return items

View file

@ -0,0 +1,135 @@
from __future__ import annotations
import uuid
from datetime import datetime, timezone
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models.category import Category
# Built-in category catalogue seeded once at startup (see seed_system_categories).
# Rows are created with user_id=None and is_system=True; list position becomes
# sort_order, so ordering here is deliberate. Colors group related categories;
# icon names appear to be lucide-style identifiers — TODO confirm.
SYSTEM_CATEGORIES = [
    # Income
    {"name": "Salary", "type": "income", "icon": "briefcase", "color": "#22c55e"},
    {"name": "Freelance", "type": "income", "icon": "laptop", "color": "#22c55e"},
    {"name": "Investment Income", "type": "income", "icon": "trending-up", "color": "#22c55e"},
    {"name": "Rental Income", "type": "income", "icon": "home", "color": "#22c55e"},
    {"name": "Benefits", "type": "income", "icon": "shield", "color": "#22c55e"},
    {"name": "Other Income", "type": "income", "icon": "plus-circle", "color": "#22c55e"},
    # Expenses — Housing
    {"name": "Rent / Mortgage", "type": "expense", "icon": "home", "color": "#6366f1"},
    {"name": "Council Tax", "type": "expense", "icon": "landmark", "color": "#6366f1"},
    {"name": "Home Insurance", "type": "expense", "icon": "shield", "color": "#6366f1"},
    {"name": "Home Maintenance", "type": "expense", "icon": "wrench", "color": "#6366f1"},
    # Utilities
    {"name": "Electricity", "type": "expense", "icon": "zap", "color": "#f59e0b"},
    {"name": "Gas", "type": "expense", "icon": "flame", "color": "#f59e0b"},
    {"name": "Water", "type": "expense", "icon": "droplets", "color": "#f59e0b"},
    {"name": "Internet", "type": "expense", "icon": "wifi", "color": "#f59e0b"},
    {"name": "Phone", "type": "expense", "icon": "smartphone", "color": "#f59e0b"},
    # Food
    {"name": "Groceries", "type": "expense", "icon": "shopping-cart", "color": "#ec4899"},
    {"name": "Eating Out", "type": "expense", "icon": "utensils", "color": "#ec4899"},
    {"name": "Coffee", "type": "expense", "icon": "coffee", "color": "#ec4899"},
    {"name": "Takeaway", "type": "expense", "icon": "package", "color": "#ec4899"},
    # Transport
    {"name": "Fuel", "type": "expense", "icon": "fuel", "color": "#0ea5e9"},
    {"name": "Public Transport", "type": "expense", "icon": "bus", "color": "#0ea5e9"},
    {"name": "Car Insurance", "type": "expense", "icon": "car", "color": "#0ea5e9"},
    {"name": "Car Maintenance", "type": "expense", "icon": "wrench", "color": "#0ea5e9"},
    {"name": "Parking", "type": "expense", "icon": "parking-circle", "color": "#0ea5e9"},
    {"name": "Taxi / Ride share", "type": "expense", "icon": "map-pin", "color": "#0ea5e9"},
    # Health
    {"name": "Healthcare", "type": "expense", "icon": "heart-pulse", "color": "#ef4444"},
    {"name": "Pharmacy", "type": "expense", "icon": "pill", "color": "#ef4444"},
    {"name": "Gym", "type": "expense", "icon": "dumbbell", "color": "#ef4444"},
    # Personal
    {"name": "Clothing", "type": "expense", "icon": "shirt", "color": "#a855f7"},
    {"name": "Personal Care", "type": "expense", "icon": "sparkles", "color": "#a855f7"},
    {"name": "Subscriptions", "type": "expense", "icon": "repeat", "color": "#a855f7"},
    {"name": "Entertainment", "type": "expense", "icon": "tv", "color": "#a855f7"},
    {"name": "Holidays", "type": "expense", "icon": "plane", "color": "#a855f7"},
    # Finance
    {"name": "Loan Repayment", "type": "expense", "icon": "credit-card", "color": "#64748b"},
    {"name": "Mortgage Payment", "type": "expense", "icon": "building", "color": "#64748b"},
    {"name": "Bank Charges", "type": "expense", "icon": "landmark", "color": "#64748b"},
    {"name": "Interest Paid", "type": "expense", "icon": "percent", "color": "#64748b"},
    # Savings
    {"name": "Savings", "type": "expense", "icon": "piggy-bank", "color": "#10b981"},
    {"name": "Investments", "type": "expense", "icon": "trending-up", "color": "#10b981"},
    # Other
    {"name": "Gifts", "type": "expense", "icon": "gift", "color": "#f97316"},
    {"name": "Education", "type": "expense", "icon": "graduation-cap", "color": "#f97316"},
    {"name": "Other Expense", "type": "expense", "icon": "more-horizontal", "color": "#64748b"},
    # Transfers
    {"name": "Transfer", "type": "transfer", "icon": "arrow-left-right", "color": "#94a3b8"},
]
async def seed_system_categories(db: AsyncSession) -> None:
    """Insert the built-in category set once; no-op if any system category exists."""
    existing = await db.scalar(
        # .is_(True) emits the same SQL predicate as `== True`, idiomatically.
        select(Category).where(Category.is_system.is_(True)).limit(1)
    )
    if existing:
        return
    now = datetime.now(timezone.utc)
    for i, cat in enumerate(SYSTEM_CATEGORIES):
        db.add(Category(
            user_id=None,  # NULL user_id marks a global/system category
            name=cat["name"],
            type=cat["type"],
            icon=cat.get("icon"),
            color=cat.get("color"),
            is_system=True,
            sort_order=i,  # preserve the curated catalogue ordering
            created_at=now,
        ))
    await db.flush()
async def list_categories(db: AsyncSession, user_id: uuid.UUID) -> list[dict]:
    """Return the user's own categories merged with the global system set,
    ordered by type, then sort_order, then name."""
    stmt = (
        select(Category)
        .where((Category.user_id == user_id) | (Category.user_id.is_(None)))
        .order_by(Category.type, Category.sort_order, Category.name)
    )
    rows = (await db.execute(stmt)).scalars().all()

    def serialize(c: Category) -> dict:
        return {
            "id": str(c.id),
            "name": c.name,
            "type": c.type,
            "icon": c.icon,
            "color": c.color,
            "is_system": c.is_system,
            "parent_id": str(c.parent_id) if c.parent_id else None,
            "sort_order": c.sort_order,
        }

    return [serialize(c) for c in rows]
async def create_category(
    db: AsyncSession,
    user_id: uuid.UUID,
    name: str,
    type_: str,
    icon: str | None = None,
    color: str | None = None,
    parent_id: uuid.UUID | None = None,
) -> dict:
    """Create a user-defined (non-system) category and return a summary dict.

    NOTE(review): the returned dict omits parent_id and sort_order, unlike
    the dicts from list_categories — confirm callers don't rely on one shape.
    """
    now = datetime.now(timezone.utc)
    cat = Category(
        user_id=user_id,
        name=name,
        type=type_,  # trailing underscore avoids shadowing the builtin `type`
        icon=icon,
        color=color,
        parent_id=parent_id,
        is_system=False,
        created_at=now,
    )
    db.add(cat)
    await db.flush()  # populate cat.id before serializing
    return {"id": str(cat.id), "name": cat.name, "type": cat.type, "icon": cat.icon, "color": cat.color, "is_system": False}

View file

@ -0,0 +1,237 @@
"""
Auto-detect CSV bank export formats and produce a column mapping.
Supports: Monzo, Starling, Revolut, Barclays, Lloyds, NatWest/RBS, HSBC, Santander.
Falls back to a generic best-effort mapping for unknown formats.
"""
from __future__ import annotations
import csv
import io
from dataclasses import dataclass, field
from typing import Literal
@dataclass
class CsvMapping:
    """Column mapping from a bank CSV export onto transaction fields.

    Amounts arrive either as one signed `amount` column or as the
    `debit`/`credit` pair (is_split() distinguishes the two shapes).
    """

    date: str
    description: str
    amount: str | None = None  # single signed amount column
    debit: str | None = None  # separate debit column (positive value = money out)
    credit: str | None = None  # separate credit column (positive value = money in)
    balance: str | None = None
    reference: str | None = None
    detected_format: str | None = None

    def is_split(self) -> bool:
        """True when both debit and credit columns were identified."""
        return not (self.debit is None or self.credit is None)
# Registry of known bank CSV export formats. ``detect`` receives the SET of
# normalised (lowercased, stripped) header names and returns True on a match.
# The first matching entry wins, so more specific signatures come first.
KNOWN_FORMATS: list[dict] = [
    {
        "name": "Monzo",
        "detect": lambda h: {"transaction id", "emoji"}.issubset(h),
        "date": "Date",
        "description": "Name",
        "amount": "Amount",
        "balance": None,
        "reference": "Notes and #tags",
    },
    {
        "name": "Starling",
        "detect": lambda h: {"counter party", "spending category"}.issubset(h),
        "date": "Date",
        "description": "Counter Party",
        "amount": "Amount (GBP)",
        "balance": "Balance (GBP)",
        "reference": "Reference",
    },
    {
        "name": "Revolut",
        "detect": lambda h: {"product", "started date", "completed date"}.issubset(h),
        "date": "Started Date",
        "description": "Description",
        "amount": "Amount",
        "balance": "Balance",
        "reference": None,
    },
    {
        "name": "Barclays",
        "detect": lambda h: {"subcategory", "memo", "number"}.issubset(h),
        "date": "Date",
        "description": "Memo",
        "amount": "Amount",
        "balance": None,
        "reference": "Subcategory",
    },
    {
        "name": "Lloyds Bank",
        "detect": lambda h: {"transaction date", "debit amount", "credit amount", "transaction description"}.issubset(h),
        "date": "Transaction Date",
        "description": "Transaction Description",
        "debit": "Debit Amount",
        "credit": "Credit Amount",
        "balance": "Balance",
        "reference": None,
    },
    {
        # Same header layout as Lloyds; kept as a separate entry for the
        # distinct display name (Lloyds is listed first, so it wins).
        "name": "Halifax",
        "detect": lambda h: {"transaction date", "debit amount", "credit amount", "transaction description"}.issubset(h),
        "date": "Transaction Date",
        "description": "Transaction Description",
        "debit": "Debit Amount",
        "credit": "Credit Amount",
        "balance": "Balance",
        "reference": None,
    },
    {
        "name": "NatWest / RBS",
        # (redundant `and "value" in h` removed — "value" is already in the subset)
        "detect": lambda h: {"date", "type", "description", "value", "balance"}.issubset(h),
        "date": "Date",
        "description": "Description",
        "amount": "Value",
        "balance": "Balance",
        "reference": None,
    },
    {
        "name": "HSBC",
        "detect": lambda h: h == {"date", "description", "amount"},
        "date": "Date",
        "description": "Description",
        "amount": "Amount",
        "balance": None,
        "reference": None,
    },
    {
        # BUGFIX: HSBC's debit/credit export previously matched the single
        # "HSBC" entry above, whose mapping referenced a non-existent
        # "Amount" column; it now gets its own debit/credit mapping.
        "name": "HSBC",
        "detect": lambda h: h == {"date", "description", "debit", "credit", "balance"},
        "date": "Date",
        "description": "Description",
        "debit": "Debit",
        "credit": "Credit",
        "balance": "Balance",
        "reference": None,
    },
    {
        "name": "Santander",
        "detect": lambda h: {"date", "description", "debit", "credit", "balance"}.issubset(h),
        "date": "Date",
        "description": "Description",
        "debit": "Debit",
        "credit": "Credit",
        "balance": "Balance",
        "reference": None,
    },
    {
        "name": "Nationwide",
        "detect": lambda h: {"date", "transaction", "payments out", "payments in", "balance"}.issubset(h),
        "date": "Date",
        "description": "Transaction",
        "debit": "Payments Out",
        "credit": "Payments In",
        "balance": "Balance",
        "reference": None,
    },
]
def _normalise_headers(raw_headers: list[str]) -> dict[str, str]:
"""Return {normalised_key: original_header}."""
return {h.strip().lower(): h.strip() for h in raw_headers if h}
def detect_format(raw_headers: list[str]) -> CsvMapping:
    """Match the headers against the known bank formats, else guess generically."""
    norm = _normalise_headers(raw_headers)
    header_keys = set(norm)

    def resolve(col: str | None) -> str | None:
        # Logical name -> actual header, case-insensitively; falls back to
        # the configured name when the header is absent.
        if col is None:
            return None
        return norm.get(col.strip().lower(), col)

    for fmt in KNOWN_FORMATS:
        if not fmt["detect"](header_keys):
            continue
        common = {
            "date": resolve(fmt["date"]) or fmt["date"],
            "description": resolve(fmt["description"]) or fmt["description"],
            "balance": resolve(fmt.get("balance")),
            "reference": resolve(fmt.get("reference")),
            "detected_format": fmt["name"],
        }
        if "debit" in fmt:
            return CsvMapping(
                debit=resolve(fmt["debit"]),
                credit=resolve(fmt["credit"]),
                **common,
            )
        return CsvMapping(amount=resolve(fmt["amount"]), **common)
    # No known format matched: best-effort guess from common column names.
    return _generic_mapping(norm)
def _generic_mapping(norm: dict[str, str]) -> CsvMapping:
    """Best-effort mapping for unrecognised exports, by common header names."""
    def first_present(*names: str) -> str | None:
        return next((norm[n] for n in names if n in norm), None)

    headers_in_order = list(norm.values())
    date_col = first_present("date", "transaction date", "trans date", "value date", "posting date")
    desc_col = first_present("description", "narrative", "details", "memo", "payee", "merchant", "name", "counter party")
    amt_col = first_present("amount", "value", "net amount", "transaction amount")
    debit_col = first_present("debit", "debit amount", "payments out", "money out", "withdrawal")
    credit_col = first_present("credit", "credit amount", "payments in", "money in", "deposit")
    bal_col = first_present("balance", "running balance")
    ref_col = first_present("reference", "notes", "tags", "category")
    # Positional fallbacks: assume first column is the date, second the description.
    if not date_col:
        date_col = headers_in_order[0] if headers_in_order else "date"
    if not desc_col:
        desc_col = headers_in_order[1] if len(headers_in_order) > 1 else "description"
    if debit_col and credit_col:
        return CsvMapping(
            date=date_col,
            description=desc_col,
            debit=debit_col,
            credit=credit_col,
            balance=bal_col,
            reference=ref_col,
            detected_format=None,
        )
    if not amt_col:
        # Third column is the traditional position for the amount.
        amt_col = headers_in_order[2] if len(headers_in_order) > 2 else "amount"
    return CsvMapping(
        date=date_col,
        description=desc_col,
        amount=amt_col,
        balance=bal_col,
        reference=ref_col,
        detected_format=None,
    )
def parse_csv_content(content: bytes) -> tuple[list[str], list[dict]]:
    """Decode a CSV upload and return ``(headers, rows)``.

    Tries UTF-8 (with and without BOM) first, then latin-1. Preamble
    lines before the real header row (emitted by some bank exports) are
    skipped by scanning for the first comma-separated line. Fully empty
    rows are dropped; all cell values are stripped strings.
    """
    text = None
    for encoding in ("utf-8-sig", "utf-8", "latin-1"):
        try:
            text = content.decode(encoding)
        except UnicodeDecodeError:
            continue
        break
    if text is None:
        raise ValueError("Cannot decode file — try saving as UTF-8")
    lines = text.splitlines()
    # First line that looks like a delimited header row wins.
    header_idx = next(
        (i for i, line in enumerate(lines) if "," in line and len(line.split(",")) >= 2),
        0,
    )
    reader = csv.DictReader(io.StringIO("\n".join(lines[header_idx:])))
    headers = [h.strip() for h in (reader.fieldnames or []) if h and h.strip()]
    rows: list[dict] = []
    for raw_row in reader:
        row = {
            k.strip(): (v.strip() if v else "")
            for k, v in raw_row.items()
            if k and k.strip()
        }
        if any(row.values()):
            rows.append(row)
    return headers, rows

View file

@ -0,0 +1,300 @@
import uuid
from datetime import date, datetime, timezone
from decimal import Decimal
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models.asset import Asset
from app.db.models.asset_price import AssetPrice
from app.db.models.investment_holding import InvestmentHolding
from app.db.models.investment_transaction import InvestmentTransaction
from app.schemas.investment import (
HoldingCreate,
HoldingResponse,
InvestmentTxnCreate,
PerformanceMetrics,
PortfolioSummary,
)
async def _get_asset(db: AsyncSession, asset_id: uuid.UUID) -> Asset | None:
    """Fetch a single Asset by primary key, or None when absent."""
    result = await db.execute(select(Asset).where(Asset.id == asset_id))
    return result.scalar_one_or_none()
def _holding_to_response(holding: InvestmentHolding, asset: Asset) -> HoldingResponse:
    """Combine a holding row and its asset into an API response.

    Valuation fields are None when the asset has no usable last price;
    the gain percentage is additionally None for a zero cost basis.
    """
    total_cost = holding.quantity * holding.avg_cost_basis
    price = asset.last_price
    # Truthiness check deliberate: a missing (or zero) price yields no valuation.
    value = holding.quantity * price if price else None
    gain = (value - total_cost) if value is not None else None
    gain_pct = None
    if gain is not None and total_cost > 0:
        gain_pct = (gain / total_cost * 100).quantize(Decimal("0.01"))
    return HoldingResponse(
        id=holding.id,
        account_id=holding.account_id,
        asset_id=holding.asset_id,
        symbol=asset.symbol,
        asset_name=asset.name,
        asset_type=asset.type,
        quantity=holding.quantity,
        avg_cost_basis=holding.avg_cost_basis,
        current_price=price,
        current_value=value,
        cost_basis_total=total_cost,
        unrealised_gain=gain,
        unrealised_gain_pct=gain_pct,
        currency=holding.currency,
        price_change_24h=asset.price_change_24h,
    )
async def get_portfolio(db: AsyncSession, user_id: uuid.UUID) -> PortfolioSummary:
    """Aggregate all open (quantity > 0) holdings into a portfolio summary.

    Holdings whose asset row is missing are skipped. Totals are GBP;
    holdings without a known price contribute cost but no value.
    """
    result = await db.execute(
        select(InvestmentHolding).where(
            InvestmentHolding.user_id == user_id,
            InvestmentHolding.quantity > 0,
        )
    )
    summaries = []
    value_total = Decimal("0")
    cost_total = Decimal("0")
    # NOTE(review): one asset lookup per holding (N+1); fine for a personal
    # portfolio size, revisit if holdings grow large.
    for holding in result.scalars().all():
        asset = await _get_asset(db, holding.asset_id)
        if asset is None:
            continue
        summary = _holding_to_response(holding, asset)
        summaries.append(summary)
        cost_total += summary.cost_basis_total
        if summary.current_value is not None:
            value_total += summary.current_value
    gain = value_total - cost_total
    gain_pct = (
        (gain / cost_total * 100).quantize(Decimal("0.01"))
        if cost_total > 0
        else Decimal("0")
    )
    return PortfolioSummary(
        total_value=value_total,
        total_cost=cost_total,
        total_gain=gain,
        total_gain_pct=gain_pct,
        currency="GBP",
        holdings=summaries,
    )
async def get_holding(db: AsyncSession, user_id: uuid.UUID, holding_id: uuid.UUID) -> InvestmentHolding | None:
    """Fetch one holding scoped to the owning user, or None."""
    query = select(InvestmentHolding).where(
        InvestmentHolding.id == holding_id,
        InvestmentHolding.user_id == user_id,
    )
    return (await db.execute(query)).scalar_one_or_none()
async def create_holding(db: AsyncSession, user_id: uuid.UUID, data: HoldingCreate) -> InvestmentHolding:
    """Create a holding for (account, asset), or return the existing one.

    NOTE(review): when a holding already exists it is returned unchanged —
    the quantity/cost carried in ``data`` are ignored; confirm intended.
    """
    duplicate_query = select(InvestmentHolding).where(
        InvestmentHolding.user_id == user_id,
        InvestmentHolding.account_id == data.account_id,
        InvestmentHolding.asset_id == data.asset_id,
    )
    existing = (await db.execute(duplicate_query)).scalar_one_or_none()
    if existing is not None:
        return existing
    stamp = datetime.now(timezone.utc)
    holding = InvestmentHolding(
        id=uuid.uuid4(),
        user_id=user_id,
        account_id=data.account_id,
        asset_id=data.asset_id,
        quantity=data.quantity,
        avg_cost_basis=data.avg_cost_basis,
        currency=data.currency,
        created_at=stamp,
        updated_at=stamp,
    )
    db.add(holding)
    await db.flush()
    await db.refresh(holding)
    return holding
async def add_investment_transaction(
    db: AsyncSession, user_id: uuid.UUID, data: InvestmentTxnCreate
) -> InvestmentTransaction:
    """Record an investment transaction and update the holding's position.

    Effects by type:
      - buy / transfer_in: increases quantity; avg cost basis becomes the
        weighted average of the old position and the new lot.
      - sell / transfer_out: decreases quantity, floored at zero.
      - split: multiplies quantity and divides cost basis by the split
        ratio carried in ``data.quantity``.
      - dividend / fee: recorded only; position unchanged.

    Raises:
        ValueError: if the holding does not exist or belongs to another user.
    """
    holding = await get_holding(db, user_id, data.holding_id)
    if not holding:
        raise ValueError("Holding not found")
    total = data.quantity * data.price + data.fees
    txn = InvestmentTransaction(
        id=uuid.uuid4(),
        user_id=user_id,
        holding_id=data.holding_id,
        type=data.type,
        quantity=data.quantity,
        price=data.price,
        fees=data.fees,
        total_amount=total,
        currency=data.currency,
        date=data.date,
        created_at=datetime.now(timezone.utc),
    )
    db.add(txn)
    # Update holding quantity and avg cost basis
    if data.type in ("buy", "transfer_in"):
        new_qty = holding.quantity + data.quantity
        if new_qty > 0:
            holding.avg_cost_basis = (
                (holding.quantity * holding.avg_cost_basis + data.quantity * data.price)
                / new_qty
            )
        holding.quantity = new_qty
    elif data.type in ("sell", "transfer_out"):
        holding.quantity = max(Decimal("0"), holding.quantity - data.quantity)
    elif data.type == "split":
        # BUGFIX: guard on the split ratio (quantity), not the price — a
        # zero ratio previously divided by zero, and a zero price caused
        # valid splits to be silently ignored.
        if data.quantity > 0:
            holding.quantity = holding.quantity * data.quantity
            holding.avg_cost_basis = holding.avg_cost_basis / data.quantity
    # dividend and fee don't affect quantity/cost basis
    holding.updated_at = datetime.now(timezone.utc)
    await db.flush()
    await db.refresh(txn)
    return txn
async def list_investment_transactions(
    db: AsyncSession, user_id: uuid.UUID, holding_id: uuid.UUID
) -> list[InvestmentTransaction]:
    """All transactions for one of the user's holdings, newest first."""
    query = (
        select(InvestmentTransaction)
        .where(
            InvestmentTransaction.user_id == user_id,
            InvestmentTransaction.holding_id == holding_id,
        )
        .order_by(InvestmentTransaction.date.desc())
    )
    return list((await db.execute(query)).scalars().all())
async def get_performance(db: AsyncSession, user_id: uuid.UUID) -> PerformanceMetrics:
    """Whole-portfolio performance figures (absolute return only for now)."""
    portfolio = await get_portfolio(db, user_id)
    return PerformanceMetrics(
        twrr=None,  # full TWRR requires snapshot history — placeholder
        total_return=portfolio.total_gain,
        total_return_pct=portfolio.total_gain_pct,
        currency="GBP",
    )
async def get_or_create_asset(
    db: AsyncSession, symbol: str, name: str, asset_type: str,
    currency: str, data_source: str, data_source_id: str | None,
    exchange: str | None = None,
) -> Asset:
    """Look up an asset by (symbol, data_source), creating it when absent.

    Symbols are stored uppercased so lookups are case-insensitive.
    """
    lookup = select(Asset).where(
        Asset.symbol == symbol.upper(), Asset.data_source == data_source
    )
    found = (await db.execute(lookup)).scalar_one_or_none()
    if found is not None:
        return found
    stamp = datetime.now(timezone.utc)
    asset = Asset(
        id=uuid.uuid4(),
        symbol=symbol.upper(),
        name=name,
        type=asset_type,
        currency=currency,
        exchange=exchange,
        data_source=data_source,
        data_source_id=data_source_id,
        is_active=True,
        created_at=stamp,
        updated_at=stamp,
    )
    db.add(asset)
    await db.flush()
    await db.refresh(asset)
    return asset
async def update_asset_price(
    db: AsyncSession, asset: Asset, price: Decimal, change_24h: Decimal | None
) -> None:
    """Store the latest quote on the asset row and stamp the timestamps."""
    stamp = datetime.now(timezone.utc)
    asset.last_price = price
    asset.price_change_24h = change_24h
    asset.last_price_at = stamp
    asset.updated_at = stamp
    await db.flush()
async def upsert_price_history(db: AsyncSession, asset_id: uuid.UUID, rows: list[dict]) -> int:
    """Insert or update daily OHLCV rows for an asset.

    Each row requires ``date`` and ``close``; open/high/low/volume are
    optional. Returns the number of rows processed.

    BUGFIX: the update path previously used ``row["open"]`` etc. and
    raised KeyError for rows missing optional fields that the insert path
    tolerated via ``.get`` — both paths now use ``.get``.
    """
    count = 0
    for row in rows:
        result = await db.execute(
            select(AssetPrice).where(
                AssetPrice.asset_id == asset_id, AssetPrice.date == row["date"]
            )
        )
        existing = result.scalar_one_or_none()
        if existing:
            existing.open = row.get("open")
            existing.high = row.get("high")
            existing.low = row.get("low")
            existing.close = row["close"]
            existing.volume = row.get("volume")
        else:
            db.add(AssetPrice(
                id=uuid.uuid4(),
                asset_id=asset_id,
                date=row["date"],
                open=row.get("open"),
                high=row.get("high"),
                low=row.get("low"),
                close=row["close"],
                volume=row.get("volume"),
                created_at=datetime.now(timezone.utc),
            ))
        count += 1
    await db.flush()
    return count
async def get_price_history(
    db: AsyncSession, asset_id: uuid.UUID, days: int = 365
) -> list[AssetPrice]:
    """Daily price rows for the last ``days`` days, oldest first."""
    from datetime import timedelta
    since = date.today() - timedelta(days=days)
    query = (
        select(AssetPrice)
        .where(AssetPrice.asset_id == asset_id, AssetPrice.date >= since)
        .order_by(AssetPrice.date.asc())
    )
    return list((await db.execute(query)).scalars().all())
async def search_assets(db: AsyncSession, query: str) -> list[Asset]:
    """Case-insensitive substring search over symbol and name (max 10 hits)."""
    from sqlalchemy import or_, func
    needle = query.strip().upper()
    stmt = (
        select(Asset)
        .where(
            or_(
                func.upper(Asset.symbol).contains(needle),
                func.upper(Asset.name).contains(needle),
            )
        )
        .limit(10)
    )
    return list((await db.execute(stmt)).scalars().all())

View file

@ -0,0 +1,116 @@
"""
Live price fetching: yfinance for stocks/ETFs, CoinGecko for crypto.
Falls back gracefully — never raises; always returns None on failure.
"""
import asyncio
from datetime import date, datetime, timezone, timedelta
from decimal import Decimal
from typing import Any
import structlog
logger = structlog.get_logger()
async def _run_sync(fn, *args):
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, fn, *args)
def _fetch_yahoo(symbol: str) -> dict | None:
    """Fetch the latest quote for *symbol* from Yahoo Finance via yfinance.

    Returns a dict with price / change_24h / currency / name / exchange,
    or None on any failure (import error, network error, unknown symbol).
    """
    try:
        import yfinance as yf
        ticker = yf.Ticker(symbol)
        info = ticker.fast_info
        # NOTE(review): `or` treats a 0 price the same as missing — presumably
        # fine for market quotes, but worth confirming.
        price = getattr(info, "last_price", None) or getattr(info, "regularMarketPrice", None)
        prev = getattr(info, "previous_close", None)
        if price is None:
            return None
        change_24h = None
        if prev and prev > 0:
            # Percentage move versus the previous close.
            change_24h = round((price - prev) / prev * 100, 4)
        return {
            # float -> str -> Decimal avoids binary-float representation artefacts.
            "price": Decimal(str(round(price, 8))),
            "change_24h": Decimal(str(change_24h)) if change_24h is not None else None,
            "currency": (getattr(info, "currency", None) or "USD").upper(),
            "name": getattr(info, "long_name", None) or symbol,
            "exchange": getattr(info, "exchange", None),
        }
    except Exception as exc:
        # Deliberate broad catch: this module's contract is "never raise".
        logger.warning("yahoo_fetch_failed", symbol=symbol, error=str(exc))
        return None
def _fetch_coingecko(coin_id: str) -> dict | None:
    """Fetch a crypto spot price (GBP preferred, USD fallback) from CoinGecko."""
    try:
        import requests
        response = requests.get(
            "https://api.coingecko.com/api/v3/simple/price",
            params={"ids": coin_id, "vs_currencies": "usd,gbp", "include_24hr_change": "true"},
            timeout=10,
        )
        response.raise_for_status()
        payload = response.json().get(coin_id, {})
        if not payload:
            return None
        spot = payload.get("gbp", payload.get("usd", 0))
        return {
            "price": Decimal(str(spot)),
            "change_24h": Decimal(str(round(payload.get("gbp_24h_change", 0), 4))),
            "currency": "GBP",
            "name": coin_id,
        }
    except Exception as exc:
        # Contract: never raise — log and signal failure with None.
        logger.warning("coingecko_fetch_failed", coin_id=coin_id, error=str(exc))
        return None
def _fetch_yahoo_history(symbol: str, days: int = 365) -> list[dict]:
    """Daily OHLCV history for *symbol* over the last *days* days.

    Returns a list of dicts with date/open/high/low/close/volume as
    Decimals, or an empty list on any failure.
    """
    try:
        import yfinance as yf
        ticker = yf.Ticker(symbol)
        hist = ticker.history(period=f"{days}d", interval="1d")
        rows = []
        for ts, row in hist.iterrows():
            rows.append({
                "date": ts.date(),
                # float -> str -> Decimal avoids binary-float representation artefacts.
                "open": Decimal(str(round(float(row["Open"]), 8))),
                "high": Decimal(str(round(float(row["High"]), 8))),
                "low": Decimal(str(round(float(row["Low"]), 8))),
                "close": Decimal(str(round(float(row["Close"]), 8))),
                # Volume can be NaN/None in the feed; coerce to 0.
                "volume": Decimal(str(int(row.get("Volume", 0) or 0))),
            })
        return rows
    except Exception as exc:
        # Contract: never raise; an empty history is the failure signal.
        logger.warning("yahoo_history_failed", symbol=symbol, error=str(exc))
        return []
async def fetch_price(symbol: str, data_source: str, data_source_id: str | None) -> dict | None:
    """Route a live-price lookup to CoinGecko (crypto) or Yahoo (everything else)."""
    if data_source == "coingecko":
        coin = data_source_id or symbol.lower()
        return await _run_sync(_fetch_coingecko, coin)
    return await _run_sync(_fetch_yahoo, symbol)
async def fetch_history(symbol: str, days: int = 365) -> list[dict]:
    """Async wrapper around the blocking Yahoo history fetch."""
    return await _run_sync(_fetch_yahoo_history, symbol, days)
def search_yahoo(query: str) -> list[dict]:
    """Probe Yahoo for *query* as a ticker symbol; return [candidate] or []."""
    try:
        import yfinance as yf
        info = yf.Ticker(query).fast_info
        price = getattr(info, "last_price", None)
        if price:
            candidate = {
                "symbol": query.upper(),
                "name": getattr(info, "long_name", None) or query.upper(),
                "type": "stock",
                "currency": (getattr(info, "currency", None) or "USD").upper(),
                "exchange": getattr(info, "exchange", None),
                "data_source": "yahoo_finance",
                "data_source_id": None,
            }
            return [candidate]
    except Exception:
        # Best-effort search: any failure simply yields no results.
        pass
    return []

View file

@ -0,0 +1,356 @@
import uuid
from datetime import date, datetime, timezone
from decimal import Decimal
from dateutil.relativedelta import relativedelta
from sqlalchemy import and_, func, select, text
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models.account import Account
from app.db.models.budget import Budget
from app.db.models.category import Category
from app.db.models.net_worth_snapshot import NetWorthSnapshot
from app.db.models.transaction import Transaction
from app.schemas.report import (
BudgetVsActualItem,
BudgetVsActualReport,
CashFlowPoint,
CashFlowReport,
CategoryBreakdownItem,
CategoryBreakdownReport,
IncomeExpensePoint,
IncomeExpenseReport,
NetWorthPoint,
NetWorthReport,
SpendingTrendPoint,
SpendingTrendsReport,
)
LIABILITY_TYPES = {"credit_card", "loan", "mortgage"}


async def _current_net_worth(db: AsyncSession, user_id: uuid.UUID) -> tuple[Decimal, Decimal]:
    """Sum live balances into ``(assets, liabilities)`` for the user.

    Only active, non-deleted accounts flagged ``include_in_net_worth``
    count; account types in LIABILITY_TYPES accumulate on the liability
    side, everything else on the asset side.
    """
    result = await db.execute(
        select(Account).where(
            Account.user_id == user_id,
            Account.include_in_net_worth == True,  # noqa: E712
            Account.is_active == True,  # noqa: E712
            Account.deleted_at.is_(None),
        )
    )
    asset_total = Decimal("0")
    liability_total = Decimal("0")
    for account in result.scalars().all():
        balance = account.current_balance or Decimal("0")
        if account.type in LIABILITY_TYPES:
            liability_total += balance
        else:
            asset_total += balance
    return asset_total, liability_total
async def get_net_worth_report(
    db: AsyncSession, user_id: uuid.UUID, base_currency: str, months: int = 12
) -> NetWorthReport:
    """Historical net-worth series plus the current live figure.

    Points come from stored NetWorthSnapshot rows within the last
    ``months`` months; the headline number is recomputed from live
    account balances rather than the latest snapshot.

    NOTE(review): change_30d is measured against the OLDEST point in the
    window — i.e. up to ``months`` months back, not 30 days. Confirm
    whether the field name or the baseline is wrong.
    """
    cutoff = date.today() - relativedelta(months=months)
    result = await db.execute(
        select(NetWorthSnapshot)
        .where(NetWorthSnapshot.user_id == user_id, NetWorthSnapshot.date >= cutoff)
        .order_by(NetWorthSnapshot.date.asc())
    )
    snapshots = result.scalars().all()
    points = [
        NetWorthPoint(
            date=s.date,
            total_assets=s.total_assets,
            total_liabilities=s.total_liabilities,
            net_worth=s.net_worth,
            base_currency=s.base_currency,
        )
        for s in snapshots
    ]
    # Live totals from current account balances (may differ from snapshots).
    assets, liabilities = await _current_net_worth(db, user_id)
    current_nw = assets - liabilities
    change_30d = Decimal("0")
    change_30d_pct = Decimal("0")
    if points:
        past_nw = points[0].net_worth
        change_30d = current_nw - past_nw
        if past_nw != 0:
            # abs() keeps the sign of the change meaningful when the
            # baseline net worth is negative.
            change_30d_pct = (change_30d / abs(past_nw) * 100).quantize(Decimal("0.01"))
    return NetWorthReport(
        points=points,
        current_net_worth=current_nw,
        change_30d=change_30d,
        change_30d_pct=change_30d_pct,
        base_currency=base_currency,
    )
async def get_income_expense_report(
    db: AsyncSession, user_id: uuid.UUID, months: int = 12
) -> IncomeExpenseReport:
    """Monthly income vs expenses for the last ``months`` calendar months.

    Aggregation happens in SQL, grouped by YYYY-MM; months with no
    transactions produce no point. The monthly averages divide by the
    number of months that actually had data, not by ``months``.
    """
    # Start of the oldest calendar month in the window.
    cutoff = (date.today().replace(day=1) - relativedelta(months=months - 1))
    result = await db.execute(
        text("""
            SELECT
                TO_CHAR(date, 'YYYY-MM') AS month,
                SUM(CASE WHEN type = 'income' THEN amount ELSE 0 END) AS income,
                SUM(CASE WHEN type = 'expense' THEN ABS(amount) ELSE 0 END) AS expenses
            FROM transactions
            WHERE user_id = CAST(:uid AS uuid)
              AND status != 'void'
              AND deleted_at IS NULL
              AND date >= :cutoff
            GROUP BY TO_CHAR(date, 'YYYY-MM')
            ORDER BY month ASC
        """).bindparams(uid=str(user_id), cutoff=cutoff)
    )
    rows = result.fetchall()
    points = []
    total_income = Decimal("0")
    total_expenses = Decimal("0")
    for row in rows:
        inc = Decimal(str(row.income or 0))
        exp = Decimal(str(row.expenses or 0))
        points.append(IncomeExpensePoint(month=row.month, income=inc, expenses=exp, net=inc - exp))
        total_income += inc
        total_expenses += exp
    # Guard against division by zero when there is no data at all.
    n = len(points) or 1
    return IncomeExpenseReport(
        points=points,
        total_income=total_income,
        total_expenses=total_expenses,
        avg_monthly_income=(total_income / n).quantize(Decimal("0.01")),
        avg_monthly_expenses=(total_expenses / n).quantize(Decimal("0.01")),
        currency="GBP",
    )
async def get_cash_flow_report(
    db: AsyncSession, user_id: uuid.UUID, date_from: date, date_to: date
) -> CashFlowReport:
    """Daily cash-flow series (inflow/outflow/net) with a running balance.

    Only income and expense rows count — transfers are excluded so money
    moved between the user's own accounts does not inflate the figures.
    The running balance starts from zero at ``date_from``.
    """
    result = await db.execute(
        text("""
            SELECT
                date,
                SUM(CASE WHEN amount > 0 THEN amount ELSE 0 END) AS inflow,
                SUM(CASE WHEN amount < 0 THEN ABS(amount) ELSE 0 END) AS outflow
            FROM transactions
            WHERE user_id = CAST(:uid AS uuid)
              AND status != 'void'
              AND deleted_at IS NULL
              AND date BETWEEN :df AND :dt
              AND type IN ('income', 'expense')
            GROUP BY date
            ORDER BY date ASC
        """).bindparams(uid=str(user_id), df=date_from, dt=date_to)
    )
    rows = result.fetchall()
    points = []
    running = Decimal("0")
    total_inflow = Decimal("0")
    total_outflow = Decimal("0")
    for row in rows:
        inflow = Decimal(str(row.inflow or 0))
        outflow = Decimal(str(row.outflow or 0))
        running += inflow - outflow
        total_inflow += inflow
        total_outflow += outflow
        points.append(
            CashFlowPoint(
                date=row.date,
                inflow=inflow,
                outflow=outflow,
                net=inflow - outflow,
                running_balance=running,
            )
        )
    return CashFlowReport(
        points=points,
        total_inflow=total_inflow,
        total_outflow=total_outflow,
        currency="GBP",
    )
async def get_category_breakdown(
    db: AsyncSession,
    user_id: uuid.UUID,
    date_from: date,
    date_to: date,
    txn_type: str = "expense",
) -> CategoryBreakdownReport:
    """Spending (or income) grouped by category over a date range.

    Amounts are absolute sums per category, ordered largest first, with
    each item's percentage share of the grand total. Transactions with no
    category (or an unknown category id) appear as "Uncategorised".

    Improvement: category names are resolved with a single ``IN`` query
    instead of one query per category (previously N+1).
    """
    result = await db.execute(
        select(
            Transaction.category_id,
            func.sum(func.abs(Transaction.amount)).label("total"),
            func.count(Transaction.id).label("cnt"),
        )
        .where(
            Transaction.user_id == user_id,
            Transaction.type == txn_type,
            Transaction.status != "void",
            Transaction.date >= date_from,
            Transaction.date <= date_to,
            Transaction.deleted_at.is_(None),
        )
        .group_by(Transaction.category_id)
        .order_by(func.sum(func.abs(Transaction.amount)).desc())
    )
    rows = result.fetchall()
    # Resolve all referenced category names in one round-trip.
    category_ids = [row.category_id for row in rows if row.category_id]
    names: dict = {}
    if category_ids:
        cat_result = await db.execute(select(Category).where(Category.id.in_(category_ids)))
        names = {c.id: c.name for c in cat_result.scalars().all()}
    grand_total = Decimal("0")
    raw = []
    for row in rows:
        amt = Decimal(str(row.total or 0))
        grand_total += amt
        cat_name = names.get(row.category_id, "Uncategorised") if row.category_id else "Uncategorised"
        raw.append((row.category_id, cat_name, amt, row.cnt))
    items = [
        CategoryBreakdownItem(
            category_id=str(cat_id) if cat_id else None,
            category_name=name,
            amount=amt,
            percent=(amt / grand_total * 100).quantize(Decimal("0.01")) if grand_total > 0 else Decimal("0"),
            transaction_count=cnt,
        )
        for cat_id, name, amt, cnt in raw
    ]
    return CategoryBreakdownReport(
        items=items,
        total=grand_total,
        currency="GBP",
        date_from=date_from,
        date_to=date_to,
    )
async def get_budget_vs_actual(db: AsyncSession, user_id: uuid.UUID) -> BudgetVsActualReport:
    """Compare each active budget against actual spend in its current period.

    For every budget, the period window comes from budget_service's
    ``_period_bounds`` for today's date, and "actual" is the absolute sum
    of non-void, non-deleted expense transactions in that window for the
    budget's category.

    NOTE(review): one category query + one spend query per budget (N+1);
    acceptable for a personal budget count, revisit if it grows.
    """
    from app.services.budget_service import list_budgets, _period_bounds
    today = date.today()
    budgets = await list_budgets(db, user_id, active_only=True)
    items = []
    total_budgeted = Decimal("0")
    total_actual = Decimal("0")
    for budget in budgets:
        period_start, period_end = _period_bounds(budget.period, today)
        cat_result = await db.execute(select(Category).where(Category.id == budget.category_id))
        category = cat_result.scalar_one_or_none()
        cat_name = category.name if category else "Unknown"
        spent_result = await db.execute(
            select(func.coalesce(func.sum(func.abs(Transaction.amount)), Decimal("0")))
            .where(
                and_(
                    Transaction.user_id == user_id,
                    Transaction.category_id == budget.category_id,
                    Transaction.type == "expense",
                    Transaction.status != "void",
                    Transaction.date >= period_start,
                    Transaction.date <= period_end,
                    Transaction.deleted_at.is_(None),
                )
            )
        )
        actual = Decimal(str(spent_result.scalar() or 0))
        # Positive variance = money left in the budget; negative = overspent.
        variance = budget.amount - actual
        pct = (actual / budget.amount * 100).quantize(Decimal("0.01")) if budget.amount > 0 else Decimal("0")
        items.append(
            BudgetVsActualItem(
                budget_id=str(budget.id),
                budget_name=budget.name,
                category_name=cat_name,
                budgeted=budget.amount,
                actual=actual,
                variance=variance,
                percent_used=pct,
            )
        )
        total_budgeted += budget.amount
        total_actual += actual
    return BudgetVsActualReport(
        items=items,
        total_budgeted=total_budgeted,
        total_actual=total_actual,
        currency="GBP",
    )
async def get_spending_trends(
    db: AsyncSession, user_id: uuid.UUID, months: int = 6
) -> SpendingTrendsReport:
    """Monthly expense totals per category for the last ``months`` months.

    Aggregated in SQL with a LEFT JOIN to categories so uncategorised
    spend still appears (as 'Uncategorised'). ``categories`` lists the
    distinct category names in first-seen order of the month-ascending
    result set.
    """
    # Start of the oldest calendar month in the window.
    cutoff = (date.today().replace(day=1) - relativedelta(months=months - 1))
    result = await db.execute(
        text("""
            SELECT
                TO_CHAR(t.date, 'YYYY-MM') AS month,
                COALESCE(c.name, 'Uncategorised') AS category_name,
                SUM(ABS(t.amount)) AS amount
            FROM transactions t
            LEFT JOIN categories c ON c.id = t.category_id
            WHERE t.user_id = CAST(:uid AS uuid)
              AND t.type = 'expense'
              AND t.status != 'void'
              AND t.deleted_at IS NULL
              AND t.date >= :cutoff
            GROUP BY TO_CHAR(t.date, 'YYYY-MM'), c.name
            ORDER BY month ASC, amount DESC
        """).bindparams(uid=str(user_id), cutoff=cutoff)
    )
    rows = result.fetchall()
    points = [
        SpendingTrendPoint(month=row.month, category_name=row.category_name, amount=Decimal(str(row.amount or 0)))
        for row in rows
    ]
    # dict.fromkeys de-duplicates while preserving insertion order.
    categories = list(dict.fromkeys(p.category_name for p in points))
    return SpendingTrendsReport(points=points, categories=categories, currency="GBP")
async def take_net_worth_snapshot(db: AsyncSession, user_id: uuid.UUID, base_currency: str) -> None:
    """Persist today's net-worth snapshot, at most once per day (idempotent)."""
    today = date.today()
    dup = await db.execute(
        select(NetWorthSnapshot).where(
            NetWorthSnapshot.user_id == user_id,
            NetWorthSnapshot.date == today,
        )
    )
    if dup.scalar_one_or_none() is not None:
        return  # already snapshotted today
    asset_total, liability_total = await _current_net_worth(db, user_id)
    db.add(
        NetWorthSnapshot(
            id=uuid.uuid4(),
            user_id=user_id,
            date=today,
            total_assets=asset_total,
            total_liabilities=liability_total,
            net_worth=asset_total - liability_total,
            base_currency=base_currency,
            breakdown={},
            created_at=datetime.now(timezone.utc),
        )
    )
    await db.flush()

View file

@ -0,0 +1,308 @@
from __future__ import annotations
import hashlib
import uuid
from datetime import datetime, timezone
from decimal import Decimal
from sqlalchemy import and_, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.security import decrypt_field, encrypt_field
from app.db.models.transaction import Transaction
from app.schemas.transaction import TransactionCreate, TransactionFilter, TransactionUpdate
from app.services.account_service import recalculate_balance
class TransactionError(Exception):
    """Domain error carrying an HTTP-style status code for the API layer."""

    def __init__(self, detail: str, status_code: int = 400):
        # Stored as attributes so route handlers can build the HTTP response.
        self.detail = detail
        self.status_code = status_code
def _enc(v: str | None) -> bytes | None:
    """Encrypt a nullable text field; None/empty values are stored as NULL."""
    if not v:
        return None
    return encrypt_field(v)
def _dec(v: bytes | None) -> str | None:
    """Decrypt a nullable encrypted field; NULL/empty stays None."""
    if not v:
        return None
    return decrypt_field(v)
def _to_response(t: Transaction) -> dict:
    """Serialise a Transaction row to an API dict, decrypting text fields."""
    payload = {
        "id": t.id,
        "account_id": t.account_id,
        "transfer_account_id": t.transfer_account_id,
        "category_id": t.category_id,
        "type": t.type,
        "status": t.status,
        "amount": t.amount,
        "amount_base": t.amount_base,
        "currency": t.currency,
        "base_currency": t.base_currency,
        "exchange_rate": t.exchange_rate,
        "date": t.date,
        # Encrypted-at-rest text fields are decrypted for the response.
        "description": _dec(t.description_enc) or "",
        "merchant": _dec(t.merchant_enc),
        "notes": _dec(t.notes_enc),
        "tags": t.tags or [],
        "is_recurring": t.is_recurring,
        "created_at": t.created_at,
        "updated_at": t.updated_at,
    }
    return payload
async def create_transaction(
    db: AsyncSession,
    user_id: uuid.UUID,
    data: TransactionCreate,
    base_currency: str,
) -> dict:
    """Create a transaction (and, for transfers, its mirror entry).

    Description/merchant/notes are encrypted at rest. For a transfer with
    a destination account, a second row is written against that account
    with the amount negated, and both account balances are recalculated.
    Returns the serialised source-side transaction.

    NOTE(review): amount_base is currently a straight copy of amount (FX
    conversion is deferred to Phase 3 per the inline comment), and
    exchange_rate is None for non-base currencies.
    """
    now = datetime.now(timezone.utc)
    amount = data.amount
    # For transfers, create mirrored entry on destination account
    txn = Transaction(
        user_id=user_id,
        account_id=data.account_id,
        transfer_account_id=data.transfer_account_id,
        category_id=data.category_id,
        type=data.type,
        status=data.status,
        amount=amount,
        amount_base=amount,  # Phase 3: convert via FX rate
        currency=data.currency,
        base_currency=base_currency,
        exchange_rate=Decimal("1") if data.currency == base_currency else None,
        date=data.date,
        description_enc=encrypt_field(data.description),
        merchant_enc=_enc(data.merchant),
        notes_enc=_enc(data.notes),
        tags=data.tags,
        is_recurring=data.is_recurring,
        recurring_rule=data.recurring_rule,
        created_at=now,
        updated_at=now,
    )
    db.add(txn)
    await db.flush()
    # If transfer, create the counter-entry on the destination account
    if data.type == "transfer" and data.transfer_account_id:
        counter = Transaction(
            user_id=user_id,
            account_id=data.transfer_account_id,
            transfer_account_id=data.account_id,
            category_id=data.category_id,
            type="transfer",
            status=data.status,
            amount=-amount,  # opposite sign
            amount_base=-amount,
            currency=data.currency,
            base_currency=base_currency,
            exchange_rate=Decimal("1") if data.currency == base_currency else None,
            date=data.date,
            description_enc=encrypt_field(data.description),
            merchant_enc=_enc(data.merchant),
            notes_enc=_enc(data.notes),
            tags=data.tags,
            is_recurring=False,  # only the source side carries the recurrence
            created_at=now,
            updated_at=now,
        )
        db.add(counter)
        await db.flush()
        await recalculate_balance(db, data.transfer_account_id)
    await recalculate_balance(db, data.account_id)
    return _to_response(txn)
async def list_transactions(
    db: AsyncSession,
    user_id: uuid.UUID,
    filters: TransactionFilter,
) -> dict:
    """Paginated transaction listing with optional filters.

    Structured filters (account, category, type, status, date and amount
    ranges) are applied in SQL; free-text ``search`` is applied in Python
    AFTER decryption, because description/merchant are encrypted at rest.

    NOTE(review): the search filter runs after pagination, so ``total``
    and ``pages`` count the unfiltered rows and a page may come back with
    fewer (even zero) matching items. Known limitation until full-text
    search lands (Phase 3 per the inline comment).
    """
    conditions = [
        Transaction.user_id == user_id,
        Transaction.deleted_at.is_(None),
    ]
    if filters.account_id:
        conditions.append(Transaction.account_id == filters.account_id)
    if filters.category_id:
        conditions.append(Transaction.category_id == filters.category_id)
    if filters.type:
        conditions.append(Transaction.type == filters.type)
    if filters.status:
        conditions.append(Transaction.status == filters.status)
    if filters.date_from:
        conditions.append(Transaction.date >= filters.date_from)
    if filters.date_to:
        conditions.append(Transaction.date <= filters.date_to)
    if filters.min_amount is not None:
        conditions.append(Transaction.amount >= filters.min_amount)
    if filters.max_amount is not None:
        conditions.append(Transaction.amount <= filters.max_amount)
    query = select(Transaction).where(and_(*conditions)).order_by(Transaction.date.desc(), Transaction.created_at.desc())
    # Count total
    from sqlalchemy import func
    count_result = await db.execute(select(func.count()).select_from(query.subquery()))
    total = count_result.scalar_one()
    # Paginate
    offset = (filters.page - 1) * filters.page_size
    query = query.offset(offset).limit(filters.page_size)
    result = await db.execute(query)
    items = [_to_response(t) for t in result.scalars()]
    # Filter by search (post-decrypt — Phase 3 will add FTS)
    if filters.search:
        term = filters.search.lower()
        items = [
            t for t in items
            if term in t["description"].lower()
            or (t["merchant"] and term in t["merchant"].lower())
        ]
    return {
        "items": items,
        "total": total,
        "page": filters.page,
        "page_size": filters.page_size,
        # Ceiling division without math.ceil; at least one page is reported.
        "pages": max(1, -(-total // filters.page_size)),
    }
async def get_transaction(db: AsyncSession, txn_id: uuid.UUID, user_id: uuid.UUID) -> Transaction:
    """Fetch one non-deleted transaction owned by the user.

    Raises:
        TransactionError: 404 when absent, soft-deleted, or owned by
            another user.
    """
    query = select(Transaction).where(
        Transaction.id == txn_id,
        Transaction.user_id == user_id,
        Transaction.deleted_at.is_(None),
    )
    found = (await db.execute(query)).scalar_one_or_none()
    if found is None:
        raise TransactionError("Transaction not found", status_code=404)
    return found
async def update_transaction(
    db: AsyncSession,
    txn_id: uuid.UUID,
    user_id: uuid.UUID,
    data: TransactionUpdate,
    base_currency: str,
) -> dict:
    """Partially update a transaction; only non-None fields of *data* apply.

    Plaintext fields are copied verbatim; description/merchant/notes are
    re-encrypted before storage.  The owning account's balance is
    recalculated after the change is flushed.

    Raises:
        TransactionError: 404 (via get_transaction) if not found/owned.
    """
    txn = await get_transaction(db, txn_id, user_id)
    now = datetime.now(timezone.utc)
    affected_account = txn.account_id

    # Fields copied straight onto the row when provided.
    for attr in ("category_id", "status", "date", "tags"):
        value = getattr(data, attr)
        if value is not None:
            setattr(txn, attr, value)

    if data.amount is not None:
        txn.amount = data.amount
        # amount_base mirrors amount here — presumably single-currency
        # updates only; confirm before supporting FX-adjusted edits.
        txn.amount_base = data.amount

    # Encrypted columns get fresh ciphertext.
    if data.description is not None:
        txn.description_enc = encrypt_field(data.description)
    if data.merchant is not None:
        txn.merchant_enc = _enc(data.merchant)
    if data.notes is not None:
        txn.notes_enc = _enc(data.notes)

    txn.updated_at = now
    await db.flush()
    await recalculate_balance(db, affected_account)
    return _to_response(txn)
async def delete_transaction(db: AsyncSession, txn_id: uuid.UUID, user_id: uuid.UUID) -> None:
    """Soft-delete a transaction and recalculate its account's balance.

    Raises:
        TransactionError: 404 (via get_transaction) if not found/owned.
    """
    txn = await get_transaction(db, txn_id, user_id)
    account_id = txn.account_id
    # Single timestamp so deleted_at == updated_at exactly; the previous
    # code called datetime.now() twice and recorded two slightly
    # different instants for the same event.
    now = datetime.now(timezone.utc)
    txn.deleted_at = now
    txn.updated_at = now
    await db.flush()
    await recalculate_balance(db, account_id)
async def import_csv(
    db: AsyncSession,
    user_id: uuid.UUID,
    account_id: uuid.UUID,
    rows: list[dict],
    base_currency: str,
) -> dict:
    """
    Import transactions from parsed CSV rows.
    Each row must have: date, description, amount
    Optional: merchant, notes, currency
    Rows are deduplicated per-user via a SHA-256 hash of
    date|description|amount; unparseable rows are skipped, not fatal.
    Returns counts of imported vs skipped (duplicates + bad rows).
    """
    # Hoisted out of the per-row loop (was re-imported every iteration);
    # the unused `from datetime import date as date_type` is dropped.
    import dateutil.parser

    imported = 0
    skipped = 0
    now = datetime.now(timezone.utc)
    for row in rows:
        # Build dedup hash from date + description + amount
        raw = f"{row['date']}|{row['description']}|{row['amount']}"
        import_hash = hashlib.sha256(raw.encode()).hexdigest()
        # Skip rows already imported for this user (any account).
        exists = await db.scalar(
            select(Transaction.id).where(
                Transaction.user_id == user_id,
                Transaction.import_hash == import_hash,
            )
        )
        if exists:
            skipped += 1
            continue
        try:
            amount = Decimal(str(row["amount"]))
            # NOTE(review): dateutil defaults to month-first parsing; for
            # UK bank CSVs consider dayfirst=True — confirm the upstream
            # format detectors normalise dates before reaching here.
            txn_date = dateutil.parser.parse(str(row["date"])).date()
        except Exception:
            skipped += 1
            continue
        # Sign convention: positive = income, zero/negative = expense.
        txn_type = "income" if amount > 0 else "expense"
        txn = Transaction(
            user_id=user_id,
            account_id=account_id,
            type=txn_type,
            status="cleared",
            amount=amount,
            amount_base=amount,
            currency=row.get("currency", base_currency),
            base_currency=base_currency,
            exchange_rate=Decimal("1"),
            date=txn_date,
            description_enc=encrypt_field(str(row.get("description", ""))),
            merchant_enc=_enc(row.get("merchant")),
            notes_enc=_enc(row.get("notes")),
            tags=[],
            is_recurring=False,
            import_hash=import_hash,
            created_at=now,
            updated_at=now,
        )
        db.add(txn)
        imported += 1
    await db.flush()
    if imported > 0:
        await recalculate_balance(db, account_id)
    return {"imported": imported, "skipped": skipped}

View file

View file

@ -0,0 +1,74 @@
import structlog
logger = structlog.get_logger()
# Currency pairs to maintain, as (base, quote).  All rates come from a
# single GBP-based feed; pairs with a non-GBP base are derived by
# inverting the feed's GBP rate in fx_sync_job.
PAIRS = [
    ("GBP", "USD"), ("GBP", "EUR"), ("GBP", "JPY"), ("GBP", "CAD"),
    ("GBP", "AUD"), ("GBP", "CHF"), ("USD", "GBP"), ("EUR", "GBP"),
]
async def fx_sync_job() -> None:
    """Fetch GBP-based FX rates and upsert the configured PAIRS.

    Fetch failures abort the job with a log entry; DB failures roll back
    the whole batch.  No-op when the session factory is not initialised.
    """
    from app.dependencies import get_session_factory
    session_factory = get_session_factory()
    if not session_factory:
        return
    try:
        import asyncio
        # NOTE(review): `requests` is not declared in pyproject — add it
        # (or switch to the already-declared httpx) — verify packaging.
        import requests
        # requests is blocking; run it in a worker thread so this
        # coroutine does not stall the event loop while waiting on I/O.
        r = await asyncio.to_thread(
            requests.get,
            "https://api.exchangerate-api.com/v4/latest/GBP",
            timeout=10,
        )
        r.raise_for_status()
        data = r.json()
        rates = data.get("rates", {})
    except Exception as exc:
        logger.error("fx_fetch_failed", error=str(exc))
        return
    from datetime import datetime, timezone
    from decimal import Decimal
    import uuid as _uuid
    from sqlalchemy import select
    from app.db.models.currency import ExchangeRate  # type: ignore[attr-defined]
    async with session_factory() as db:
        try:
            now = datetime.now(timezone.utc)
            for base, quote in PAIRS:
                if base == "GBP":
                    rate_val = rates.get(quote)
                else:
                    # Feed is GBP-based: derive e.g. USD->GBP as 1/(GBP->USD).
                    gbp_to_base = rates.get(base)
                    if not gbp_to_base or gbp_to_base == 0:
                        continue
                    rate_val = 1 / gbp_to_base
                if not rate_val:
                    continue
                # Upsert: update the existing pair row or insert a new one.
                result = await db.execute(
                    select(ExchangeRate).where(
                        ExchangeRate.base_currency == base,
                        ExchangeRate.quote_currency == quote,
                    )
                )
                existing = result.scalar_one_or_none()
                if existing:
                    existing.rate = Decimal(str(round(rate_val, 8)))
                    existing.fetched_at = now
                else:
                    db.add(ExchangeRate(
                        id=_uuid.uuid4(),
                        base_currency=base,
                        quote_currency=quote,
                        rate=Decimal(str(round(rate_val, 8))),
                        source="exchangerate-api",
                        fetched_at=now,
                    ))
            await db.commit()
            logger.info("fx_sync_done", pairs=len(PAIRS))
        except Exception as exc:
            await db.rollback()
            logger.error("fx_sync_db_failed", error=str(exc))

View file

@ -0,0 +1,31 @@
import structlog
from sqlalchemy import select
logger = structlog.get_logger()
async def price_sync_job() -> None:
    """Refresh the latest price for every active asset.

    Best-effort: a failing feed for one asset is logged and skipped so the
    remaining assets still sync; DB errors roll back the whole batch.
    No-op when the session factory is not initialised.
    """
    from app.dependencies import get_session_factory
    from app.db.models.asset import Asset
    from app.services.price_feed_service import fetch_price
    from app.services.investment_service import update_asset_price
    session_factory = get_session_factory()
    if not session_factory:
        return
    async with session_factory() as db:
        try:
            # .is_(True) matches the `.is_(None)` style used elsewhere and
            # avoids the noqa'd `== True` comparison.
            result = await db.execute(select(Asset).where(Asset.is_active.is_(True)))
            assets = result.scalars().all()
            updated = 0
            for asset in assets:
                try:
                    data = await fetch_price(asset.symbol, asset.data_source, asset.data_source_id)
                except Exception as exc:
                    # One bad symbol/feed must not abort the whole sync.
                    logger.error("price_fetch_failed", symbol=asset.symbol, error=str(exc))
                    continue
                if data and data.get("price"):
                    await update_asset_price(db, asset, data["price"], data.get("change_24h"))
                    updated += 1
            await db.commit()
            logger.info("price_sync_done", updated=updated, total=len(assets))
        except Exception as exc:
            await db.rollback()
            logger.error("price_sync_failed", error=str(exc))

View file

@ -0,0 +1,33 @@
"""
APScheduler background jobs. Starts with the FastAPI lifespan.
"""
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
import structlog
logger = structlog.get_logger()
_scheduler: AsyncIOScheduler | None = None
async def start_scheduler() -> None:
    """Create the global AsyncIOScheduler, register all jobs, and start it.

    Registered jobs:
      - nw_snapshot: nightly net-worth snapshot at 02:00
      - price_sync:  asset price refresh every 15 minutes
      - fx_sync:     FX rate refresh at the top of every hour
    """
    global _scheduler
    from app.workers.snapshot import snapshot_job
    from app.workers.price_sync import price_sync_job
    from app.workers.fx_sync import fx_sync_job

    _scheduler = AsyncIOScheduler()
    job_table = (
        (snapshot_job, CronTrigger(hour=2, minute=0), "nw_snapshot"),
        (price_sync_job, CronTrigger(minute="*/15"), "price_sync"),
        (fx_sync_job, CronTrigger(minute=0), "fx_sync"),
    )
    for job_fn, trigger, job_id in job_table:
        _scheduler.add_job(job_fn, trigger, id=job_id)
    # TODO: register once implemented:
    #   backup_job     -> CronTrigger(hour=3)
    #   ml_retrain_job -> CronTrigger(day_of_week="sun", hour=1)
    _scheduler.start()
    logger.info("scheduler_started")
async def stop_scheduler() -> None:
    """Shut down the global scheduler if it was started and is running."""
    scheduler = _scheduler
    if scheduler is None or not scheduler.running:
        return
    scheduler.shutdown(wait=False)
    logger.info("scheduler_stopped")

View file

@ -0,0 +1,23 @@
import structlog
from sqlalchemy import select
logger = structlog.get_logger()
async def snapshot_job() -> None:
    """Take a net-worth snapshot for every non-deleted user.

    Commits once after all users are processed; any failure rolls back the
    whole batch.  No-op when the session factory is not initialised.
    """
    from app.dependencies import get_session_factory
    from app.db.models.user import User
    from app.services.report_service import take_net_worth_snapshot
    session_factory = get_session_factory()
    # Guard against the app not being initialised yet — the sibling
    # workers (price_sync, fx_sync) bail out the same way; without this,
    # `session_factory()` would raise TypeError on None.
    if not session_factory:
        return
    async with session_factory() as db:
        try:
            result = await db.execute(select(User).where(User.deleted_at.is_(None)))
            users = result.scalars().all()
            for user in users:
                await take_net_worth_snapshot(db, user.id, user.base_currency)
            await db.commit()
            logger.info("snapshot_job_done", users=len(users))
        except Exception as exc:
            await db.rollback()
            logger.error("snapshot_job_failed", error=str(exc))

55
backend/pyproject.toml Normal file
View file

@ -0,0 +1,55 @@
[project]
name = "finance-tracker"
version = "0.1.0"
requires-python = ">=3.12"
dependencies = [
"fastapi>=0.115",
"uvicorn[standard]>=0.30",
"sqlalchemy[asyncio]>=2.0",
"asyncpg>=0.30",
"alembic>=1.14",
"redis[hiredis]>=5.2",
"pydantic[email]>=2.10",
"pydantic-settings>=2.7",
"argon2-cffi>=23.1",
"python-jose[cryptography]>=3.3",
"pyotp>=2.9",
"qrcode[pil]>=8.0",
"cryptography>=44.0",
"yfinance>=0.2",
"prophet>=1.1",
"statsmodels>=0.14",
"numpy>=2.0",
"scipy>=1.14",
"pandas>=2.2",
"joblib>=1.4",
"apscheduler>=3.10",
"python-multipart>=0.0.12",
"httpx>=0.27",
"requests>=2.32",
"python-dateutil>=2.9",
"slowapi>=0.1.9",
"structlog>=24.0",
"pillow>=11.0",
"python-magic>=0.4",
"psycopg2-binary>=2.9",
]
[project.optional-dependencies]
dev = [
"pytest>=8.0",
"pytest-asyncio>=0.24",
"pytest-cov>=6.0",
"httpx>=0.27",
"anyio>=4.0",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
[tool.hatch.build.targets.wheel]
packages = ["app"]