Initial commit: MyMidas personal finance tracker
Full-stack self-hosted finance app with FastAPI backend and React frontend. Features: - Accounts, transactions, budgets, investments with GBP base currency - CSV import with auto-detection for 10 UK bank formats - ML predictions: spending forecast, net worth projection, Monte Carlo - 7 selectable themes (Obsidian, Arctic, Midnight, Vault, Terminal, Synthwave, Ledger) - Receipt/document attachments on transactions (JPEG, PNG, WebP, PDF) - AES-256-GCM field encryption, RS256 JWT, TOTP 2FA, RLS, audit log - Encrypted nightly backups + key rotation script - Mobile-responsive layout with bottom nav Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
commit
61a7884ee5
127 changed files with 13323 additions and 0 deletions
0
backend/app/__init__.py
Normal file
0
backend/app/__init__.py
Normal file
0
backend/app/api/__init__.py
Normal file
0
backend/app/api/__init__.py
Normal file
14
backend/app/api/router.py
Normal file
14
backend/app/api/router.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
from fastapi import APIRouter
|
||||
|
||||
from app.api.v1 import auth, users, accounts, categories, transactions, budgets, reports, investments, predictions
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(auth.router, prefix="/auth", tags=["auth"])
|
||||
router.include_router(users.router, prefix="/users", tags=["users"])
|
||||
router.include_router(accounts.router, prefix="/accounts", tags=["accounts"])
|
||||
router.include_router(categories.router, prefix="/categories", tags=["categories"])
|
||||
router.include_router(transactions.router, prefix="/transactions", tags=["transactions"])
|
||||
router.include_router(budgets.router)
|
||||
router.include_router(reports.router)
|
||||
router.include_router(investments.router)
|
||||
router.include_router(predictions.router)
|
||||
0
backend/app/api/v1/__init__.py
Normal file
0
backend/app/api/v1/__init__.py
Normal file
236
backend/app/api/v1/accounts.py
Normal file
236
backend/app/api/v1/accounts.py
Normal file
|
|
@ -0,0 +1,236 @@
|
|||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.audit import write_audit
|
||||
from app.dependencies import get_current_user, get_db
|
||||
from app.schemas.account import AccountCreate, AccountResponse, AccountUpdate
|
||||
from app.services.account_service import (
|
||||
AccountError,
|
||||
create_account,
|
||||
delete_account,
|
||||
get_account,
|
||||
get_net_worth,
|
||||
list_accounts,
|
||||
update_account,
|
||||
)
|
||||
|
||||
MAX_IMPORT_FILE_BYTES = 10 * 1024 * 1024 # 10 MB
|
||||
MAX_IMPORT_ROWS = 50_000
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=list[AccountResponse])
|
||||
async def get_accounts(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return await list_accounts(db, user.id)
|
||||
|
||||
|
||||
@router.post("", response_model=AccountResponse, status_code=201)
|
||||
async def create(
|
||||
body: AccountCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
result = await create_account(db, user.id, body)
|
||||
await write_audit(db, user_id=user.id, action="account_create")
|
||||
await db.commit()
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/net-worth")
|
||||
async def net_worth(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return await get_net_worth(db, user.id, user.base_currency)
|
||||
|
||||
|
||||
@router.get("/{account_id}", response_model=AccountResponse)
|
||||
async def get_one(
|
||||
account_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
try:
|
||||
account = await get_account(db, account_id, user.id)
|
||||
from app.services.account_service import _to_response
|
||||
return _to_response(account)
|
||||
except AccountError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.detail)
|
||||
|
||||
|
||||
@router.put("/{account_id}", response_model=AccountResponse)
|
||||
async def update(
|
||||
account_id: uuid.UUID,
|
||||
body: AccountUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
try:
|
||||
result = await update_account(db, account_id, user.id, body)
|
||||
await write_audit(db, user_id=user.id, action="account_update", resource_type="account", resource_id=account_id)
|
||||
await db.commit()
|
||||
return result
|
||||
except AccountError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.detail)
|
||||
|
||||
|
||||
@router.post("/{account_id}/import/preview")
|
||||
async def import_preview(
|
||||
account_id: uuid.UUID,
|
||||
file: UploadFile = File(...),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Upload a CSV and get back the detected format, column mapping, and a sample of parsed rows."""
|
||||
from app.services.csv_detector import parse_csv_content, detect_format
|
||||
|
||||
try:
|
||||
await get_account(db, account_id, user.id)
|
||||
except AccountError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.detail)
|
||||
|
||||
content = await file.read(MAX_IMPORT_FILE_BYTES + 1)
|
||||
if len(content) > MAX_IMPORT_FILE_BYTES:
|
||||
raise HTTPException(status_code=413, detail="File too large (max 10 MB)")
|
||||
try:
|
||||
headers, rows = parse_csv_content(content)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
if not headers:
|
||||
raise HTTPException(status_code=400, detail="Could not read CSV headers")
|
||||
|
||||
mapping = detect_format(headers)
|
||||
|
||||
# Build 5-row preview using the detected mapping
|
||||
preview = []
|
||||
for row in rows[:5]:
|
||||
entry: dict = {
|
||||
"date_raw": row.get(mapping.date, ""),
|
||||
"description_raw": row.get(mapping.description, ""),
|
||||
}
|
||||
if mapping.is_split():
|
||||
debit_str = row.get(mapping.debit or "", "").replace(",", "").replace("£", "").strip()
|
||||
credit_str = row.get(mapping.credit or "", "").replace(",", "").replace("£", "").strip()
|
||||
try:
|
||||
debit = float(debit_str) if debit_str else 0.0
|
||||
credit = float(credit_str) if credit_str else 0.0
|
||||
entry["amount_raw"] = credit - debit
|
||||
except ValueError:
|
||||
entry["amount_raw"] = None
|
||||
else:
|
||||
raw = row.get(mapping.amount or "", "").replace(",", "").replace("£", "").strip()
|
||||
try:
|
||||
entry["amount_raw"] = float(raw) if raw else None
|
||||
except ValueError:
|
||||
entry["amount_raw"] = None
|
||||
if mapping.balance:
|
||||
entry["balance_raw"] = row.get(mapping.balance, "")
|
||||
preview.append(entry)
|
||||
|
||||
return {
|
||||
"detected_format": mapping.detected_format,
|
||||
"headers": headers,
|
||||
"mapping": {
|
||||
"date": mapping.date,
|
||||
"description": mapping.description,
|
||||
"amount": mapping.amount,
|
||||
"debit": mapping.debit,
|
||||
"credit": mapping.credit,
|
||||
"balance": mapping.balance,
|
||||
"reference": mapping.reference,
|
||||
},
|
||||
"total_rows": len(rows),
|
||||
"preview": preview,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{account_id}/import")
|
||||
async def import_csv_to_account(
|
||||
account_id: uuid.UUID,
|
||||
file: UploadFile = File(...),
|
||||
date_col: str = Form(...),
|
||||
description_col: str = Form(...),
|
||||
amount_col: str = Form(default=""),
|
||||
debit_col: str = Form(default=""),
|
||||
credit_col: str = Form(default=""),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
from app.services.csv_detector import parse_csv_content
|
||||
from app.services.transaction_service import import_csv
|
||||
from app.core.audit import write_audit
|
||||
|
||||
try:
|
||||
await get_account(db, account_id, user.id)
|
||||
except AccountError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.detail)
|
||||
|
||||
content = await file.read(MAX_IMPORT_FILE_BYTES + 1)
|
||||
if len(content) > MAX_IMPORT_FILE_BYTES:
|
||||
raise HTTPException(status_code=413, detail="File too large (max 10 MB)")
|
||||
try:
|
||||
_, rows = parse_csv_content(content)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
if len(rows) > MAX_IMPORT_ROWS:
|
||||
raise HTTPException(status_code=400, detail=f"File contains too many rows (max {MAX_IMPORT_ROWS:,})")
|
||||
|
||||
use_split = bool(debit_col and credit_col)
|
||||
parsed_rows = []
|
||||
|
||||
for row in rows:
|
||||
date_val = row.get(date_col, "").strip()
|
||||
desc_val = row.get(description_col, "").strip() or "Imported transaction"
|
||||
|
||||
if use_split:
|
||||
debit_str = row.get(debit_col, "").replace(",", "").replace("£", "").strip()
|
||||
credit_str = row.get(credit_col, "").replace(",", "").replace("£", "").strip()
|
||||
try:
|
||||
debit = float(debit_str) if debit_str else 0.0
|
||||
credit = float(credit_str) if credit_str else 0.0
|
||||
amount = credit - debit
|
||||
except ValueError:
|
||||
continue
|
||||
else:
|
||||
raw = row.get(amount_col, "").replace(",", "").replace("£", "").strip()
|
||||
try:
|
||||
amount = float(raw) if raw else None
|
||||
except ValueError:
|
||||
continue
|
||||
if amount is None:
|
||||
continue
|
||||
|
||||
if not date_val:
|
||||
continue
|
||||
|
||||
parsed_rows.append({"date": date_val, "description": desc_val, "amount": str(amount)})
|
||||
|
||||
if not parsed_rows:
|
||||
raise HTTPException(status_code=400, detail="No valid rows found after applying column mapping")
|
||||
|
||||
result = await import_csv(db, user.id, account_id, parsed_rows, user.base_currency)
|
||||
await write_audit(db, user_id=user.id, action="import_data", metadata=result)
|
||||
await db.commit()
|
||||
return result
|
||||
|
||||
|
||||
@router.delete("/{account_id}", status_code=204)
|
||||
async def delete(
|
||||
account_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
try:
|
||||
await delete_account(db, account_id, user.id)
|
||||
await write_audit(db, user_id=user.id, action="account_delete", resource_type="account", resource_id=account_id)
|
||||
await db.commit()
|
||||
except AccountError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.detail)
|
||||
342
backend/app/api/v1/auth.py
Normal file
342
backend/app/api/v1/auth.py
Normal file
|
|
@ -0,0 +1,342 @@
|
|||
"""
|
||||
Auth endpoints: register, login, TOTP, refresh, logout, sessions.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.audit import write_audit
|
||||
from app.core.rate_limiter import is_rate_limited
|
||||
from app.core.security import create_refresh_token, decode_token, generate_csrf_token, hash_token
|
||||
from app.dependencies import get_current_user, get_db, get_redis
|
||||
from app.schemas.auth import (
|
||||
LoginRequest,
|
||||
RegisterRequest,
|
||||
SessionInfo,
|
||||
TOTPChallengeResponse,
|
||||
TOTPLoginRequest,
|
||||
TOTPSetupResponse,
|
||||
TOTPVerifyRequest,
|
||||
TokenResponse,
|
||||
)
|
||||
from app.services.auth_service import (
|
||||
AuthError,
|
||||
authenticate_user,
|
||||
complete_totp_login,
|
||||
create_totp_challenge_token,
|
||||
disable_totp,
|
||||
enable_totp,
|
||||
get_sessions,
|
||||
register_user,
|
||||
revoke_all_sessions,
|
||||
revoke_session,
|
||||
setup_totp,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _ip(request: Request) -> str | None:
|
||||
forwarded = request.headers.get("X-Forwarded-For")
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
return request.client.host if request.client else None
|
||||
|
||||
|
||||
def _ua(request: Request) -> str | None:
|
||||
return request.headers.get("User-Agent")
|
||||
|
||||
|
||||
def _set_refresh_cookie(response: Response, token: str) -> None:
|
||||
response.set_cookie(
|
||||
"refresh_token",
|
||||
token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite="strict",
|
||||
max_age=7 * 24 * 3600,
|
||||
path="/api/v1/auth",
|
||||
)
|
||||
|
||||
|
||||
def _set_csrf_cookie(response: Response, token: str) -> None:
|
||||
response.set_cookie(
|
||||
"csrf_token",
|
||||
token,
|
||||
httponly=False,
|
||||
secure=True,
|
||||
samesite="strict",
|
||||
max_age=86400,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/register", status_code=201)
|
||||
async def register(
|
||||
body: RegisterRequest,
|
||||
request: Request,
|
||||
response: Response,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
try:
|
||||
user = await register_user(db, body.email, body.password, body.display_name)
|
||||
await write_audit(db, user_id=user.id, action="register", ip_address=_ip(request))
|
||||
await db.commit()
|
||||
except AuthError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.detail)
|
||||
return {"message": "Account created. Please log in."}
|
||||
|
||||
|
||||
@router.post("/login")
|
||||
async def login(
|
||||
body: LoginRequest,
|
||||
request: Request,
|
||||
response: Response,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
try:
|
||||
user, access_token, refresh_token = await authenticate_user(
|
||||
db, redis, body.email, body.password, _ip(request), _ua(request)
|
||||
)
|
||||
except AuthError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.detail)
|
||||
|
||||
if access_token is None:
|
||||
# TOTP required
|
||||
challenge_token = create_totp_challenge_token(user.id)
|
||||
await write_audit(db, user_id=user.id, action="login", ip_address=_ip(request), metadata={"totp_required": True})
|
||||
await db.commit()
|
||||
return TOTPChallengeResponse(challenge_token=challenge_token)
|
||||
|
||||
csrf = generate_csrf_token()
|
||||
_set_refresh_cookie(response, refresh_token)
|
||||
_set_csrf_cookie(response, csrf)
|
||||
await write_audit(db, user_id=user.id, action="login", ip_address=_ip(request))
|
||||
await db.commit()
|
||||
|
||||
settings_expire = 15 * 60
|
||||
return TokenResponse(access_token=access_token, expires_in=settings_expire)
|
||||
|
||||
|
||||
@router.post("/login/totp")
|
||||
async def login_totp(
|
||||
body: TOTPLoginRequest,
|
||||
request: Request,
|
||||
response: Response,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
ip = _ip(request) or "unknown"
|
||||
limited, _ = await is_rate_limited(redis, f"rate:totp:{ip}", limit=10, window_seconds=60)
|
||||
if limited:
|
||||
raise HTTPException(status_code=429, detail="Too many TOTP attempts — try again shortly")
|
||||
|
||||
try:
|
||||
access_token, refresh_token = await complete_totp_login(
|
||||
db, body.challenge_token, body.totp_code, _ip(request), _ua(request)
|
||||
)
|
||||
except AuthError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.detail)
|
||||
|
||||
csrf = generate_csrf_token()
|
||||
_set_refresh_cookie(response, refresh_token)
|
||||
_set_csrf_cookie(response, csrf)
|
||||
await db.commit()
|
||||
|
||||
return TokenResponse(access_token=access_token, expires_in=15 * 60)
|
||||
|
||||
|
||||
@router.post("/refresh")
|
||||
async def refresh_token(
|
||||
request: Request,
|
||||
response: Response,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
token = request.cookies.get("refresh_token")
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="No refresh token")
|
||||
|
||||
try:
|
||||
payload = decode_token(token, token_type="refresh")
|
||||
except Exception:
|
||||
raise HTTPException(status_code=401, detail="Invalid refresh token")
|
||||
|
||||
import uuid
|
||||
from app.core.security import create_access_token
|
||||
from sqlalchemy import select
|
||||
from datetime import datetime, timezone
|
||||
from app.db.models.session import Session
|
||||
|
||||
user_id = uuid.UUID(payload["sub"])
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Find and update session
|
||||
result = await db.execute(
|
||||
select(Session).where(
|
||||
Session.user_id == user_id,
|
||||
Session.revoked_at.is_(None),
|
||||
Session.expires_at > now,
|
||||
)
|
||||
)
|
||||
session = result.scalars().first()
|
||||
if not session:
|
||||
raise HTTPException(status_code=401, detail="Session not found")
|
||||
|
||||
new_access = create_access_token(str(user_id))
|
||||
new_refresh = create_refresh_token(str(user_id))
|
||||
|
||||
# Rotate session token hash
|
||||
session.token_hash = hash_token(new_access)
|
||||
session.last_active_at = now
|
||||
await db.commit()
|
||||
|
||||
csrf = generate_csrf_token()
|
||||
_set_refresh_cookie(response, new_refresh)
|
||||
_set_csrf_cookie(response, csrf)
|
||||
return TokenResponse(access_token=new_access, expires_in=15 * 60)
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
async def logout(
|
||||
request: Request,
|
||||
response: Response,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
token = request.headers.get("Authorization", "")[7:]
|
||||
th = hash_token(token)
|
||||
await revoke_session_by_hash(db, th, user.id)
|
||||
await write_audit(db, user_id=user.id, action="logout", ip_address=_ip(request))
|
||||
await db.commit()
|
||||
response.delete_cookie("refresh_token", path="/api/v1/auth")
|
||||
response.delete_cookie("csrf_token")
|
||||
return {"message": "Logged out"}
|
||||
|
||||
|
||||
async def revoke_session_by_hash(db, token_hash: str, user_id):
|
||||
from sqlalchemy import select, update
|
||||
from datetime import datetime, timezone
|
||||
from app.db.models.session import Session
|
||||
await db.execute(
|
||||
update(Session)
|
||||
.where(Session.user_id == user_id, Session.token_hash == token_hash)
|
||||
.values(revoked_at=datetime.now(timezone.utc))
|
||||
)
|
||||
|
||||
|
||||
@router.post("/logout-all")
|
||||
async def logout_all(
|
||||
request: Request,
|
||||
response: Response,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
token = request.headers.get("Authorization", "")[7:]
|
||||
await revoke_all_sessions(db, user.id)
|
||||
await write_audit(db, user_id=user.id, action="logout_all", ip_address=_ip(request))
|
||||
await db.commit()
|
||||
response.delete_cookie("refresh_token", path="/api/v1/auth")
|
||||
response.delete_cookie("csrf_token")
|
||||
return {"message": "All sessions revoked"}
|
||||
|
||||
|
||||
@router.get("/sessions", response_model=list[SessionInfo])
|
||||
async def list_sessions(
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
token = request.headers.get("Authorization", "")[7:]
|
||||
current_hash = hash_token(token)
|
||||
sessions = await get_sessions(db, user.id)
|
||||
result = []
|
||||
for s in sessions:
|
||||
info = SessionInfo.model_validate(s)
|
||||
info.is_current = (s.token_hash == current_hash)
|
||||
result.append(info)
|
||||
return result
|
||||
|
||||
|
||||
@router.delete("/sessions/{session_id}", status_code=204)
|
||||
async def delete_session(
|
||||
session_id,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
import uuid
|
||||
try:
|
||||
sid = uuid.UUID(str(session_id))
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=422, detail="Invalid session ID")
|
||||
try:
|
||||
await revoke_session(db, sid, user.id)
|
||||
except AuthError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.detail)
|
||||
await write_audit(db, user_id=user.id, action="session_revoke", resource_type="session", resource_id=sid)
|
||||
await db.commit()
|
||||
|
||||
|
||||
@router.get("/totp/setup", response_model=TOTPSetupResponse)
|
||||
async def totp_setup(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
secret, qr_b64, backup_codes = await setup_totp(user, db)
|
||||
return TOTPSetupResponse(secret=secret, qr_code_png_b64=qr_b64, backup_codes=backup_codes)
|
||||
|
||||
|
||||
@router.post("/totp/verify", status_code=200)
|
||||
async def totp_verify(
|
||||
body: TOTPVerifyRequest,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
# Secret must be passed back from setup — here we expect it stored temporarily in body
|
||||
# In practice the client stores it until verification; it's never persisted until verified
|
||||
# This endpoint receives the secret + verification code
|
||||
# For simplicity we accept: {"secret": "...", "code": "..."}
|
||||
# Redefine body inline:
|
||||
raise HTTPException(status_code=400, detail="Use /totp/enable endpoint with secret and code")
|
||||
|
||||
|
||||
@router.post("/totp/enable", status_code=200)
|
||||
async def totp_enable(
|
||||
body: dict,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
secret = body.get("secret")
|
||||
code = body.get("code")
|
||||
if not secret or not code:
|
||||
raise HTTPException(status_code=422, detail="secret and code required")
|
||||
try:
|
||||
await enable_totp(user, db, secret, code)
|
||||
except AuthError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.detail)
|
||||
await write_audit(db, user_id=user.id, action="totp_enable", ip_address=_ip(request))
|
||||
await db.commit()
|
||||
return {"message": "TOTP enabled"}
|
||||
|
||||
|
||||
@router.delete("/totp", status_code=200)
|
||||
async def totp_disable(
|
||||
body: dict,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
password = body.get("password")
|
||||
if not password:
|
||||
raise HTTPException(status_code=422, detail="password required")
|
||||
try:
|
||||
await disable_totp(user, db, password)
|
||||
except AuthError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.detail)
|
||||
await write_audit(db, user_id=user.id, action="totp_disable", ip_address=_ip(request))
|
||||
await db.commit()
|
||||
return {"message": "TOTP disabled"}
|
||||
79
backend/app/api/v1/budgets.py
Normal file
79
backend/app/api/v1/budgets.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_current_user, get_db
|
||||
from app.db.models.user import User
|
||||
from app.schemas.budget import BudgetCreate, BudgetResponse, BudgetSummaryItem, BudgetUpdate
|
||||
from app.services import budget_service
|
||||
|
||||
router = APIRouter(prefix="/budgets", tags=["budgets"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[BudgetResponse])
|
||||
async def list_budgets(
|
||||
active_only: bool = True,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
return await budget_service.list_budgets(db, current_user.id, active_only)
|
||||
|
||||
|
||||
@router.post("", response_model=BudgetResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_budget(
|
||||
data: BudgetCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
budget = await budget_service.create_budget(db, current_user.id, data)
|
||||
await db.commit()
|
||||
return budget
|
||||
|
||||
|
||||
@router.get("/summary", response_model=list[BudgetSummaryItem])
|
||||
async def get_budget_summary(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
return await budget_service.get_budget_summary(db, current_user.id)
|
||||
|
||||
|
||||
@router.get("/{budget_id}", response_model=BudgetResponse)
|
||||
async def get_budget(
|
||||
budget_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
budget = await budget_service.get_budget(db, current_user.id, budget_id)
|
||||
if not budget:
|
||||
raise HTTPException(status_code=404, detail="Budget not found")
|
||||
return budget
|
||||
|
||||
|
||||
@router.put("/{budget_id}", response_model=BudgetResponse)
|
||||
async def update_budget(
|
||||
budget_id: uuid.UUID,
|
||||
data: BudgetUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
budget = await budget_service.get_budget(db, current_user.id, budget_id)
|
||||
if not budget:
|
||||
raise HTTPException(status_code=404, detail="Budget not found")
|
||||
budget = await budget_service.update_budget(db, budget, data)
|
||||
await db.commit()
|
||||
return budget
|
||||
|
||||
|
||||
@router.delete("/{budget_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_budget(
|
||||
budget_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
budget = await budget_service.get_budget(db, current_user.id, budget_id)
|
||||
if not budget:
|
||||
raise HTTPException(status_code=404, detail="Budget not found")
|
||||
await budget_service.delete_budget(db, budget)
|
||||
await db.commit()
|
||||
36
backend/app/api/v1/categories.py
Normal file
36
backend/app/api/v1/categories.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_current_user, get_db
|
||||
from app.services.category_service import create_category, list_categories
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_categories(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return await list_categories(db, user.id)
|
||||
|
||||
|
||||
@router.post("", status_code=201)
|
||||
async def create(
|
||||
body: dict,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
result = await create_category(
|
||||
db,
|
||||
user_id=user.id,
|
||||
name=body["name"],
|
||||
type_=body["type"],
|
||||
icon=body.get("icon"),
|
||||
color=body.get("color"),
|
||||
parent_id=uuid.UUID(body["parent_id"]) if body.get("parent_id") else None,
|
||||
)
|
||||
await db.commit()
|
||||
return result
|
||||
199
backend/app/api/v1/investments.py
Normal file
199
backend/app/api/v1/investments.py
Normal file
|
|
@ -0,0 +1,199 @@
|
|||
import uuid
|
||||
from datetime import date
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_current_user, get_db
|
||||
from app.db.models.user import User
|
||||
from app.schemas.investment import (
|
||||
AssetSearch,
|
||||
AssetPricePoint,
|
||||
HoldingCreate,
|
||||
HoldingResponse,
|
||||
InvestmentTxnCreate,
|
||||
InvestmentTxnResponse,
|
||||
PerformanceMetrics,
|
||||
PortfolioSummary,
|
||||
)
|
||||
from app.services import investment_service
|
||||
from app.services.price_feed_service import search_yahoo, fetch_history
|
||||
|
||||
router = APIRouter(tags=["investments"])
|
||||
|
||||
|
||||
# ── Portfolio ──────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/investments/portfolio", response_model=PortfolioSummary)
|
||||
async def get_portfolio(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
return await investment_service.get_portfolio(db, current_user.id)
|
||||
|
||||
|
||||
@router.get("/investments/performance", response_model=PerformanceMetrics)
|
||||
async def get_performance(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
return await investment_service.get_performance(db, current_user.id)
|
||||
|
||||
|
||||
# ── Holdings ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/investments/holdings", response_model=HoldingResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_holding(
|
||||
data: HoldingCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
from app.db.models.asset import Asset
|
||||
from sqlalchemy import select
|
||||
asset_result = await db.execute(select(Asset).where(Asset.id == data.asset_id))
|
||||
asset = asset_result.scalar_one_or_none()
|
||||
if not asset:
|
||||
raise HTTPException(status_code=404, detail="Asset not found")
|
||||
|
||||
holding = await investment_service.create_holding(db, current_user.id, data)
|
||||
await db.commit()
|
||||
await db.refresh(holding)
|
||||
return investment_service._holding_to_response(holding, asset)
|
||||
|
||||
|
||||
@router.delete("/investments/holdings/{holding_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_holding(
|
||||
holding_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
holding = await investment_service.get_holding(db, current_user.id, holding_id)
|
||||
if not holding:
|
||||
raise HTTPException(status_code=404, detail="Holding not found")
|
||||
await db.delete(holding)
|
||||
await db.commit()
|
||||
|
||||
|
||||
# ── Investment transactions ────────────────────────────────────────────────
|
||||
|
||||
@router.get("/investments/holdings/{holding_id}/transactions", response_model=list[InvestmentTxnResponse])
|
||||
async def list_transactions(
|
||||
holding_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
return await investment_service.list_investment_transactions(db, current_user.id, holding_id)
|
||||
|
||||
|
||||
@router.post("/investments/transactions", response_model=InvestmentTxnResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def add_transaction(
|
||||
data: InvestmentTxnCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
try:
|
||||
txn = await investment_service.add_investment_transaction(db, current_user.id, data)
|
||||
await db.commit()
|
||||
return txn
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
# ── Assets ─────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/assets/search", response_model=list[AssetSearch])
|
||||
async def search_assets(
|
||||
q: str = Query(..., min_length=1),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
# First search the local DB
|
||||
local = await investment_service.search_assets(db, q)
|
||||
if local:
|
||||
from app.db.models.asset import Asset
|
||||
return [AssetSearch(
|
||||
id=a.id, symbol=a.symbol, name=a.name, type=a.type,
|
||||
currency=a.currency, exchange=a.exchange,
|
||||
last_price=a.last_price, price_change_24h=a.price_change_24h,
|
||||
data_source=a.data_source,
|
||||
) for a in local]
|
||||
|
||||
# Fall back to live Yahoo search
|
||||
import asyncio
|
||||
loop = asyncio.get_event_loop()
|
||||
results = await loop.run_in_executor(None, search_yahoo, q)
|
||||
if not results:
|
||||
return []
|
||||
|
||||
# Upsert into DB so future searches are local
|
||||
created = []
|
||||
for r in results:
|
||||
asset = await investment_service.get_or_create_asset(
|
||||
db, r["symbol"], r["name"], r["type"],
|
||||
r["currency"], r["data_source"], r.get("data_source_id"),
|
||||
r.get("exchange"),
|
||||
)
|
||||
created.append(asset)
|
||||
await db.commit()
|
||||
|
||||
return [AssetSearch(
|
||||
id=a.id, symbol=a.symbol, name=a.name, type=a.type,
|
||||
currency=a.currency, exchange=a.exchange,
|
||||
last_price=a.last_price, price_change_24h=a.price_change_24h,
|
||||
data_source=a.data_source,
|
||||
) for a in created]
|
||||
|
||||
|
||||
@router.get("/assets/{asset_id}/prices", response_model=list[AssetPricePoint])
|
||||
async def get_price_history(
|
||||
asset_id: uuid.UUID,
|
||||
days: int = Query(default=365, ge=7, le=1825),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
from app.db.models.asset import Asset
|
||||
from sqlalchemy import select
|
||||
asset_result = await db.execute(select(Asset).where(Asset.id == asset_id))
|
||||
asset = asset_result.scalar_one_or_none()
|
||||
if not asset:
|
||||
raise HTTPException(status_code=404, detail="Asset not found")
|
||||
|
||||
# Fetch from DB; if sparse, refresh from Yahoo
|
||||
prices = await investment_service.get_price_history(db, asset_id, days)
|
||||
if len(prices) < 5 and asset.data_source == "yahoo_finance":
|
||||
rows = await fetch_history(asset.symbol, days)
|
||||
if rows:
|
||||
await investment_service.upsert_price_history(db, asset_id, rows)
|
||||
await db.commit()
|
||||
prices = await investment_service.get_price_history(db, asset_id, days)
|
||||
|
||||
return [
|
||||
AssetPricePoint(
|
||||
date=p.date, open=p.open, high=p.high, low=p.low,
|
||||
close=p.close, volume=p.volume,
|
||||
)
|
||||
for p in prices
|
||||
]
|
||||
|
||||
|
||||
@router.post("/assets", response_model=AssetSearch, status_code=status.HTTP_201_CREATED)
|
||||
async def create_asset(
|
||||
symbol: str,
|
||||
name: str,
|
||||
asset_type: str = "stock",
|
||||
currency: str = "GBP",
|
||||
data_source: str = "yahoo_finance",
|
||||
data_source_id: str | None = None,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
asset = await investment_service.get_or_create_asset(
|
||||
db, symbol, name, asset_type, currency, data_source, data_source_id
|
||||
)
|
||||
await db.commit()
|
||||
return AssetSearch(
|
||||
id=asset.id, symbol=asset.symbol, name=asset.name, type=asset.type,
|
||||
currency=asset.currency, exchange=asset.exchange,
|
||||
last_price=asset.last_price, price_change_24h=asset.price_change_24h,
|
||||
data_source=asset.data_source,
|
||||
)
|
||||
236
backend/app/api/v1/predictions.py
Normal file
236
backend/app/api/v1/predictions.py
Normal file
|
|
@ -0,0 +1,236 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
import calendar
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.rate_limiter import is_rate_limited
|
||||
from app.dependencies import get_current_user, get_db, get_redis
|
||||
from app.ml.feature_engineering import (
|
||||
get_monthly_category_spending,
|
||||
get_monthly_net_worth,
|
||||
get_current_month_spending,
|
||||
get_portfolio_monthly_returns,
|
||||
get_daily_cash_flow,
|
||||
)
|
||||
from app.ml.spending_forecast import forecast_spending
|
||||
from app.ml.net_worth_projection import project_net_worth
|
||||
from app.ml.monte_carlo import run_monte_carlo
|
||||
|
||||
router = APIRouter(prefix="/predictions", tags=["predictions"])
|
||||
|
||||
|
||||
async def _check_prediction_rate(redis: Redis, user_id: str) -> None:
|
||||
limited, _ = await is_rate_limited(redis, f"rate:pred:{user_id}", limit=20, window_seconds=60)
|
||||
if limited:
|
||||
raise HTTPException(status_code=429, detail="Too many prediction requests — try again shortly")
|
||||
|
||||
|
||||
@router.get("/spending")
|
||||
async def spending_forecast(
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
await _check_prediction_rate(redis, str(user.id))
|
||||
df = await get_monthly_category_spending(db, user.id)
|
||||
categories = forecast_spending(df)
|
||||
return {"categories": categories}
|
||||
|
||||
|
||||
@router.get("/net-worth")
|
||||
async def net_worth_projection(
|
||||
years: int = 5,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
await _check_prediction_rate(redis, str(user.id))
|
||||
years = max(1, min(10, years))
|
||||
df = await get_monthly_net_worth(db, user.id)
|
||||
result = project_net_worth(df, years=years)
|
||||
return result
|
||||
|
||||
|
||||
class MonteCarloRequest(BaseModel):
|
||||
years: int = 5
|
||||
n_simulations: int = 1000
|
||||
annual_contribution: float = 0.0
|
||||
|
||||
|
||||
@router.post("/monte-carlo")
|
||||
async def monte_carlo(
|
||||
body: MonteCarloRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
await _check_prediction_rate(redis, str(user.id))
|
||||
years = max(1, min(10, body.years))
|
||||
n_sims = max(100, min(5000, body.n_simulations))
|
||||
|
||||
# Get portfolio holdings
|
||||
result = await db.execute(text("""
|
||||
SELECT h.id, a.symbol, h.quantity::float, a.last_price::float,
|
||||
(h.quantity * COALESCE(a.last_price, h.avg_cost_basis))::float AS current_value,
|
||||
h.currency
|
||||
FROM investment_holdings h
|
||||
JOIN assets a ON a.id = h.asset_id
|
||||
WHERE h.user_id = CAST(:uid AS uuid)
|
||||
AND h.deleted_at IS NULL
|
||||
AND h.quantity > 0
|
||||
"""), {"uid": str(user.id)})
|
||||
holdings = [
|
||||
{"symbol": r[1], "quantity": r[2], "last_price": r[3], "current_value": r[4]}
|
||||
for r in result.fetchall()
|
||||
]
|
||||
|
||||
prices_df = await get_portfolio_monthly_returns(db, user.id)
|
||||
|
||||
mc = run_monte_carlo(
|
||||
prices_df=prices_df,
|
||||
holdings=holdings,
|
||||
years=years,
|
||||
n_sims=n_sims,
|
||||
annual_contribution=body.annual_contribution,
|
||||
)
|
||||
return mc
|
||||
|
||||
|
||||
@router.get("/budget-forecast")
|
||||
async def budget_forecast(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
await _check_prediction_rate(redis, str(user.id))
|
||||
today = date.today()
|
||||
days_in_month = calendar.monthrange(today.year, today.month)[1]
|
||||
day_of_month = today.day
|
||||
days_remaining = days_in_month - day_of_month
|
||||
|
||||
# Get budgets
|
||||
bgt_result = await db.execute(text("""
|
||||
SELECT b.id::text, COALESCE(c.name, 'Uncategorised') AS cat_name,
|
||||
b.category_id::text, b.amount::float
|
||||
FROM budgets b
|
||||
LEFT JOIN categories c ON c.id = b.category_id
|
||||
WHERE b.user_id = CAST(:uid AS uuid)
|
||||
AND b.period = 'monthly'
|
||||
AND (b.end_date IS NULL OR b.end_date >= CURRENT_DATE)
|
||||
"""), {"uid": str(user.id)})
|
||||
budgets = {r[2]: {"budget_id": r[0], "category_name": r[1], "amount": r[3]} for r in bgt_result.fetchall()}
|
||||
|
||||
if not budgets:
|
||||
return {"forecasts": [], "message": "No monthly budgets set"}
|
||||
|
||||
# Get current month spending per category
|
||||
spent_df = await get_current_month_spending(db, user.id)
|
||||
spent_map = {row["category_id"]: row["spent"] for _, row in spent_df.iterrows()}
|
||||
|
||||
forecasts = []
|
||||
for cat_id, bgt in budgets.items():
|
||||
spent = spent_map.get(cat_id, 0.0)
|
||||
budget_amt = bgt["amount"]
|
||||
|
||||
# Daily velocity
|
||||
velocity = spent / day_of_month if day_of_month > 0 else 0.0
|
||||
forecast_total = spent + velocity * days_remaining
|
||||
|
||||
# Probability of overspend using a rough normal distribution
|
||||
# Assume uncertainty grows with days remaining
|
||||
import math
|
||||
sigma = velocity * math.sqrt(days_remaining) * 0.3 if velocity > 0 else 1.0
|
||||
if sigma > 0:
|
||||
z = (budget_amt - forecast_total) / sigma
|
||||
# CDF of normal
|
||||
import scipy.stats
|
||||
prob_overspend = float(1 - scipy.stats.norm.cdf(z))
|
||||
else:
|
||||
prob_overspend = 1.0 if forecast_total > budget_amt else 0.0
|
||||
|
||||
forecasts.append({
|
||||
"category_id": cat_id,
|
||||
"category_name": bgt["category_name"],
|
||||
"budget_amount": round(budget_amt, 2),
|
||||
"spent_so_far": round(spent, 2),
|
||||
"forecast_month_total": round(max(spent, forecast_total), 2),
|
||||
"daily_velocity": round(velocity, 2),
|
||||
"probability_overspend": round(prob_overspend, 3),
|
||||
"days_remaining": days_remaining,
|
||||
})
|
||||
|
||||
forecasts.sort(key=lambda x: x["probability_overspend"], reverse=True)
|
||||
return {"forecasts": forecasts}
|
||||
|
||||
|
||||
@router.get("/cashflow")
|
||||
async def cashflow_forecast(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
await _check_prediction_rate(redis, str(user.id))
|
||||
from datetime import timedelta
|
||||
import numpy as np
|
||||
|
||||
# Historical daily cash flows (last 90 days)
|
||||
hist_df = await get_daily_cash_flow(db, user.id, days=90)
|
||||
|
||||
# Get current account balances
|
||||
acct_result = await db.execute(text("""
|
||||
SELECT SUM(
|
||||
CASE WHEN type IN ('credit_card','loan','mortgage') THEN -ABS(current_balance)
|
||||
ELSE current_balance END
|
||||
)::float AS total_balance
|
||||
FROM accounts
|
||||
WHERE user_id = CAST(:uid AS uuid)
|
||||
AND is_active = TRUE
|
||||
AND include_in_net_worth = TRUE
|
||||
AND deleted_at IS NULL
|
||||
"""), {"uid": str(user.id)})
|
||||
row = acct_result.fetchone()
|
||||
current_balance = float(row[0] or 0.0)
|
||||
|
||||
# Compute average daily inflow / outflow from history
|
||||
if not hist_df.empty:
|
||||
avg_inflow = float(hist_df["inflow"].mean())
|
||||
avg_outflow = float(hist_df["outflow"].mean())
|
||||
std_net = float((hist_df["inflow"] - hist_df["outflow"]).std())
|
||||
else:
|
||||
avg_inflow = 0.0
|
||||
avg_outflow = 0.0
|
||||
std_net = 0.0
|
||||
|
||||
# Project 30 days forward
|
||||
today = date.today()
|
||||
daily = []
|
||||
running_balance = current_balance
|
||||
for i in range(1, 31):
|
||||
d = today + timedelta(days=i)
|
||||
net = avg_inflow - avg_outflow
|
||||
running_balance += net
|
||||
daily.append({
|
||||
"date": d.strftime("%Y-%m-%d"),
|
||||
"balance": round(running_balance, 2),
|
||||
"avg_inflow": round(avg_inflow, 2),
|
||||
"avg_outflow": round(avg_outflow, 2),
|
||||
"negative_risk": running_balance < 0,
|
||||
})
|
||||
|
||||
negative_days = [d["date"] for d in daily if d["negative_risk"]]
|
||||
|
||||
return {
|
||||
"current_balance": round(current_balance, 2),
|
||||
"avg_daily_inflow": round(avg_inflow, 2),
|
||||
"avg_daily_outflow": round(avg_outflow, 2),
|
||||
"forecast": daily,
|
||||
"negative_risk_days": negative_days,
|
||||
"history_days": len(hist_df),
|
||||
}
|
||||
82
backend/app/api/v1/reports.py
Normal file
82
backend/app/api/v1/reports.py
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
from datetime import date, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_current_user, get_db
|
||||
from app.db.models.user import User
|
||||
from app.schemas.report import (
|
||||
BudgetVsActualReport,
|
||||
CashFlowReport,
|
||||
CategoryBreakdownReport,
|
||||
IncomeExpenseReport,
|
||||
NetWorthReport,
|
||||
SpendingTrendsReport,
|
||||
)
|
||||
from app.services import report_service
|
||||
|
||||
router = APIRouter(prefix="/reports", tags=["reports"])
|
||||
|
||||
|
||||
@router.get("/net-worth", response_model=NetWorthReport)
|
||||
async def net_worth_report(
|
||||
months: int = Query(default=12, ge=1, le=60),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
return await report_service.get_net_worth_report(
|
||||
db, current_user.id, current_user.base_currency, months
|
||||
)
|
||||
|
||||
|
||||
@router.get("/income-vs-expense", response_model=IncomeExpenseReport)
|
||||
async def income_expense_report(
|
||||
months: int = Query(default=12, ge=1, le=60),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
return await report_service.get_income_expense_report(db, current_user.id, months)
|
||||
|
||||
|
||||
@router.get("/cash-flow", response_model=CashFlowReport)
|
||||
async def cash_flow_report(
|
||||
date_from: date = Query(default=None),
|
||||
date_to: date = Query(default=None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
today = date.today()
|
||||
df = date_from or (today - timedelta(days=30))
|
||||
dt = date_to or today
|
||||
return await report_service.get_cash_flow_report(db, current_user.id, df, dt)
|
||||
|
||||
|
||||
@router.get("/category-breakdown", response_model=CategoryBreakdownReport)
|
||||
async def category_breakdown(
|
||||
date_from: date = Query(default=None),
|
||||
date_to: date = Query(default=None),
|
||||
type: str = Query(default="expense"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
today = date.today()
|
||||
df = date_from or date(today.year, today.month, 1)
|
||||
dt = date_to or today
|
||||
return await report_service.get_category_breakdown(db, current_user.id, df, dt, type)
|
||||
|
||||
|
||||
@router.get("/budget-vs-actual", response_model=BudgetVsActualReport)
|
||||
async def budget_vs_actual(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
return await report_service.get_budget_vs_actual(db, current_user.id)
|
||||
|
||||
|
||||
@router.get("/spending-trends", response_model=SpendingTrendsReport)
|
||||
async def spending_trends(
|
||||
months: int = Query(default=6, ge=1, le=24),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
return await report_service.get_spending_trends(db, current_user.id, months)
|
||||
332
backend/app/api/v1/transactions.py
Normal file
332
backend/app/api/v1/transactions.py
Normal file
|
|
@ -0,0 +1,332 @@
|
|||
import csv
|
||||
import io
|
||||
import mimetypes
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import get_settings
|
||||
from app.core.audit import write_audit
|
||||
from app.dependencies import get_current_user, get_db
|
||||
from app.schemas.transaction import TransactionCreate, TransactionFilter, TransactionUpdate
|
||||
from app.services.transaction_service import (
|
||||
TransactionError,
|
||||
create_transaction,
|
||||
delete_transaction,
|
||||
get_transaction,
|
||||
import_csv,
|
||||
list_transactions,
|
||||
update_transaction,
|
||||
_to_response,
|
||||
)
|
||||
|
||||
MAX_IMPORT_FILE_BYTES = 10 * 1024 * 1024 # 10 MB
|
||||
MAX_IMPORT_ROWS = 50_000
|
||||
|
||||
ALLOWED_MIME_TYPES = {
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/webp",
|
||||
"application/pdf",
|
||||
}
|
||||
ALLOWED_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".pdf"}
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_transactions(
|
||||
account_id: uuid.UUID | None = None,
|
||||
category_id: uuid.UUID | None = None,
|
||||
type: str | None = None,
|
||||
status: str | None = None,
|
||||
date_from: str | None = None,
|
||||
date_to: str | None = None,
|
||||
search: str | None = None,
|
||||
page: int = Query(default=1, ge=1),
|
||||
page_size: int = Query(default=50, ge=1, le=200),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
from datetime import date
|
||||
filters = TransactionFilter(
|
||||
account_id=account_id,
|
||||
category_id=category_id,
|
||||
type=type,
|
||||
status=status,
|
||||
date_from=date.fromisoformat(date_from) if date_from else None,
|
||||
date_to=date.fromisoformat(date_to) if date_to else None,
|
||||
search=search,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return await list_transactions(db, user.id, filters)
|
||||
|
||||
|
||||
@router.post("", status_code=201)
|
||||
async def create(
|
||||
body: TransactionCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
try:
|
||||
result = await create_transaction(db, user.id, body, user.base_currency)
|
||||
await write_audit(db, user_id=user.id, action="transaction_create")
|
||||
await db.commit()
|
||||
return result
|
||||
except TransactionError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.detail)
|
||||
|
||||
|
||||
@router.get("/{txn_id}")
|
||||
async def get_one(
|
||||
txn_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
try:
|
||||
txn = await get_transaction(db, txn_id, user.id)
|
||||
return _to_response(txn)
|
||||
except TransactionError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.detail)
|
||||
|
||||
|
||||
@router.put("/{txn_id}")
|
||||
async def update(
|
||||
txn_id: uuid.UUID,
|
||||
body: TransactionUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
try:
|
||||
result = await update_transaction(db, txn_id, user.id, body, user.base_currency)
|
||||
await write_audit(db, user_id=user.id, action="transaction_update", resource_type="transaction", resource_id=txn_id)
|
||||
await db.commit()
|
||||
return result
|
||||
except TransactionError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.detail)
|
||||
|
||||
|
||||
@router.delete("/{txn_id}", status_code=204)
|
||||
async def delete(
|
||||
txn_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
try:
|
||||
await delete_transaction(db, txn_id, user.id)
|
||||
await write_audit(db, user_id=user.id, action="transaction_delete", resource_type="transaction", resource_id=txn_id)
|
||||
await db.commit()
|
||||
except TransactionError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.detail)
|
||||
|
||||
|
||||
@router.post("/{txn_id}/attachments")
|
||||
async def upload_attachment(
|
||||
txn_id: uuid.UUID,
|
||||
file: UploadFile = File(...),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
settings = get_settings()
|
||||
|
||||
# Validate extension
|
||||
filename = file.filename or "upload"
|
||||
ext = Path(filename).suffix.lower()
|
||||
if ext not in ALLOWED_EXTENSIONS:
|
||||
raise HTTPException(status_code=400, detail="Unsupported file type. Allowed: JPG, PNG, WebP, PDF")
|
||||
|
||||
# Verify transaction ownership
|
||||
try:
|
||||
txn = await get_transaction(db, txn_id, user.id)
|
||||
except TransactionError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.detail)
|
||||
|
||||
current_refs: list = txn.get("attachment_refs", []) if isinstance(txn, dict) else []
|
||||
# Fetch raw model for JSONB mutation
|
||||
from sqlalchemy import select
|
||||
from app.db.models.transaction import Transaction as TxnModel
|
||||
result = await db.execute(
|
||||
select(TxnModel).where(TxnModel.id == txn_id, TxnModel.user_id == user.id)
|
||||
)
|
||||
txn_row = result.scalar_one_or_none()
|
||||
if not txn_row:
|
||||
raise HTTPException(status_code=404, detail="Transaction not found")
|
||||
|
||||
current_refs = list(txn_row.attachment_refs or [])
|
||||
if len(current_refs) >= settings.max_attachments_per_txn:
|
||||
raise HTTPException(status_code=400, detail=f"Maximum {settings.max_attachments_per_txn} attachments per transaction")
|
||||
|
||||
# Read and size-check
|
||||
content = await file.read(settings.max_attachment_bytes + 1)
|
||||
if len(content) > settings.max_attachment_bytes:
|
||||
raise HTTPException(status_code=413, detail="File too large (max 10 MB)")
|
||||
|
||||
# Sniff MIME from content
|
||||
import magic # python-magic
|
||||
detected_mime = magic.from_buffer(content[:2048], mime=True)
|
||||
if detected_mime not in ALLOWED_MIME_TYPES:
|
||||
raise HTTPException(status_code=400, detail="File content does not match an allowed type (JPEG, PNG, WebP, PDF)")
|
||||
|
||||
# Store file
|
||||
attachment_id = str(uuid.uuid4())
|
||||
user_upload_dir = Path(settings.upload_dir) / str(user.id)
|
||||
user_upload_dir.mkdir(parents=True, exist_ok=True)
|
||||
stored_name = f"{attachment_id}{ext}"
|
||||
stored_path = user_upload_dir / stored_name
|
||||
stored_path.write_bytes(content)
|
||||
|
||||
# Update attachment_refs
|
||||
ref = {
|
||||
"id": attachment_id,
|
||||
"filename": filename,
|
||||
"mime_type": detected_mime,
|
||||
"size": len(content),
|
||||
"stored_name": stored_name,
|
||||
}
|
||||
from sqlalchemy import update as sql_update
|
||||
import copy
|
||||
new_refs = copy.copy(current_refs)
|
||||
new_refs.append(ref)
|
||||
await db.execute(
|
||||
sql_update(TxnModel)
|
||||
.where(TxnModel.id == txn_id)
|
||||
.values(attachment_refs=new_refs)
|
||||
)
|
||||
await write_audit(db, user_id=user.id, action="transaction_update", resource_type="transaction", resource_id=txn_id, metadata={"attachment_added": attachment_id})
|
||||
await db.commit()
|
||||
return ref
|
||||
|
||||
|
||||
@router.get("/{txn_id}/attachments/{attachment_id}")
|
||||
async def download_attachment(
|
||||
txn_id: uuid.UUID,
|
||||
attachment_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
settings = get_settings()
|
||||
|
||||
from sqlalchemy import select
|
||||
from app.db.models.transaction import Transaction as TxnModel
|
||||
result = await db.execute(
|
||||
select(TxnModel).where(TxnModel.id == txn_id, TxnModel.user_id == user.id)
|
||||
)
|
||||
txn_row = result.scalar_one_or_none()
|
||||
if not txn_row:
|
||||
raise HTTPException(status_code=404, detail="Transaction not found")
|
||||
|
||||
ref = next((r for r in (txn_row.attachment_refs or []) if r["id"] == attachment_id), None)
|
||||
if not ref:
|
||||
raise HTTPException(status_code=404, detail="Attachment not found")
|
||||
|
||||
path = Path(settings.upload_dir) / str(user.id) / ref["stored_name"]
|
||||
if not path.exists():
|
||||
raise HTTPException(status_code=404, detail="Attachment file missing")
|
||||
|
||||
return FileResponse(
|
||||
path=str(path),
|
||||
media_type=ref["mime_type"],
|
||||
filename=ref["filename"],
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{txn_id}/attachments/{attachment_id}", status_code=204)
|
||||
async def delete_attachment(
|
||||
txn_id: uuid.UUID,
|
||||
attachment_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
settings = get_settings()
|
||||
|
||||
from sqlalchemy import select, update as sql_update
|
||||
from app.db.models.transaction import Transaction as TxnModel
|
||||
result = await db.execute(
|
||||
select(TxnModel).where(TxnModel.id == txn_id, TxnModel.user_id == user.id)
|
||||
)
|
||||
txn_row = result.scalar_one_or_none()
|
||||
if not txn_row:
|
||||
raise HTTPException(status_code=404, detail="Transaction not found")
|
||||
|
||||
refs = list(txn_row.attachment_refs or [])
|
||||
ref = next((r for r in refs if r["id"] == attachment_id), None)
|
||||
if not ref:
|
||||
raise HTTPException(status_code=404, detail="Attachment not found")
|
||||
|
||||
# Delete file
|
||||
path = Path(settings.upload_dir) / str(user.id) / ref["stored_name"]
|
||||
try:
|
||||
path.unlink(missing_ok=True)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
new_refs = [r for r in refs if r["id"] != attachment_id]
|
||||
await db.execute(
|
||||
sql_update(TxnModel)
|
||||
.where(TxnModel.id == txn_id)
|
||||
.values(attachment_refs=new_refs)
|
||||
)
|
||||
await write_audit(db, user_id=user.id, action="transaction_update", resource_type="transaction", resource_id=txn_id, metadata={"attachment_deleted": attachment_id})
|
||||
await db.commit()
|
||||
|
||||
|
||||
@router.post("/import")
|
||||
async def import_transactions(
|
||||
file: UploadFile = File(...),
|
||||
account_id: uuid.UUID = Form(...),
|
||||
date_col: str = Form(default="date"),
|
||||
description_col: str = Form(default="description"),
|
||||
amount_col: str = Form(default="amount"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
if not file.filename or not file.filename.lower().endswith(".csv"):
|
||||
raise HTTPException(status_code=400, detail="Only CSV files are supported")
|
||||
|
||||
content = await file.read(MAX_IMPORT_FILE_BYTES + 1)
|
||||
if len(content) > MAX_IMPORT_FILE_BYTES:
|
||||
raise HTTPException(status_code=413, detail="File too large (max 10 MB)")
|
||||
try:
|
||||
text = content.decode("utf-8-sig") # handle BOM
|
||||
except UnicodeDecodeError:
|
||||
text = content.decode("latin-1")
|
||||
|
||||
reader = csv.DictReader(io.StringIO(text))
|
||||
rows = []
|
||||
for row in reader:
|
||||
if len(rows) >= MAX_IMPORT_ROWS:
|
||||
raise HTTPException(status_code=400, detail=f"File contains too many rows (max {MAX_IMPORT_ROWS:,})")
|
||||
mapped = {}
|
||||
# Flexible column mapping
|
||||
for key, col in [("date", date_col), ("description", description_col), ("amount", amount_col)]:
|
||||
val = row.get(col) or row.get(col.lower()) or row.get(col.upper())
|
||||
if val is not None:
|
||||
mapped[key] = val.strip()
|
||||
if "date" in mapped and "amount" in mapped:
|
||||
mapped.setdefault("description", "Imported transaction")
|
||||
rows.append(mapped)
|
||||
|
||||
if not rows:
|
||||
raise HTTPException(status_code=400, detail="No valid rows found. Check column names.")
|
||||
|
||||
result = await import_csv(db, user.id, account_id, rows, user.base_currency)
|
||||
await write_audit(db, user_id=user.id, action="import_data", metadata=result)
|
||||
await db.commit()
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/import/template")
|
||||
async def import_template():
|
||||
from fastapi.responses import Response
|
||||
csv_content = "date,description,amount,merchant,notes\n2026-01-15,Tesco Groceries,-45.67,Tesco,\n2026-01-14,Salary,2500.00,Employer,January salary\n"
|
||||
return Response(
|
||||
content=csv_content,
|
||||
media_type="text/csv",
|
||||
headers={"Content-Disposition": "attachment; filename=import_template.csv"},
|
||||
)
|
||||
126
backend/app/api/v1/users.py
Normal file
126
backend/app/api/v1/users.py
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
import csv
|
||||
import io
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.audit import write_audit
|
||||
from app.core.security import hash_password, verify_password
|
||||
from app.dependencies import get_current_user, get_db
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/me")
|
||||
async def get_me(user=Depends(get_current_user)):
|
||||
return {
|
||||
"id": str(user.id),
|
||||
"email": user.email,
|
||||
"display_name": user.display_name,
|
||||
"base_currency": user.base_currency,
|
||||
"theme": user.theme,
|
||||
"locale": user.locale,
|
||||
"totp_enabled": user.totp_enabled,
|
||||
"last_login_at": user.last_login_at,
|
||||
"created_at": user.created_at,
|
||||
}
|
||||
|
||||
|
||||
class PasswordChangeRequest(BaseModel):
|
||||
current_password: str
|
||||
new_password: str = Field(..., min_length=10)
|
||||
|
||||
|
||||
@router.post("/me/password", status_code=200)
|
||||
async def change_password(
|
||||
body: PasswordChangeRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
if not verify_password(body.current_password, user.password_hash):
|
||||
raise HTTPException(status_code=400, detail="Current password is incorrect")
|
||||
user.password_hash = hash_password(body.new_password)
|
||||
user.updated_at = datetime.now(timezone.utc)
|
||||
await write_audit(db, user_id=user.id, action="password_change")
|
||||
await db.commit()
|
||||
return {"message": "Password updated successfully"}
|
||||
|
||||
|
||||
class ProfileUpdateRequest(BaseModel):
|
||||
display_name: str | None = Field(default=None, max_length=100)
|
||||
base_currency: str | None = Field(default=None, min_length=3, max_length=10)
|
||||
|
||||
|
||||
@router.put("/me", status_code=200)
|
||||
async def update_profile(
|
||||
body: ProfileUpdateRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
if body.display_name is not None:
|
||||
user.display_name = body.display_name
|
||||
if body.base_currency is not None:
|
||||
user.base_currency = body.base_currency.upper()
|
||||
user.updated_at = datetime.now(timezone.utc)
|
||||
await db.commit()
|
||||
return {"message": "Profile updated"}
|
||||
|
||||
|
||||
@router.get("/me/export")
|
||||
async def export_data(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
from app.db.models.transaction import Transaction
|
||||
from app.db.models.account import Account
|
||||
from app.db.models.category import Category
|
||||
from app.core.security import decrypt_field
|
||||
|
||||
result = await db.execute(
|
||||
select(Transaction, Account, Category)
|
||||
.join(Account, Account.id == Transaction.account_id)
|
||||
.outerjoin(Category, Category.id == Transaction.category_id)
|
||||
.where(
|
||||
Transaction.user_id == user.id,
|
||||
Transaction.deleted_at.is_(None),
|
||||
)
|
||||
.order_by(Transaction.date.desc())
|
||||
)
|
||||
rows = result.all()
|
||||
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
writer.writerow([
|
||||
"date", "description", "merchant", "amount", "currency",
|
||||
"type", "status", "category", "account", "notes", "tags",
|
||||
])
|
||||
|
||||
for txn, account, category in rows:
|
||||
writer.writerow([
|
||||
txn.date.isoformat(),
|
||||
decrypt_field(txn.description_enc) or "",
|
||||
decrypt_field(txn.merchant_enc) if txn.merchant_enc else "",
|
||||
str(txn.amount),
|
||||
txn.currency,
|
||||
txn.type,
|
||||
txn.status,
|
||||
category.name if category else "",
|
||||
decrypt_field(account.name_enc) or "",
|
||||
decrypt_field(txn.notes_enc) if txn.notes_enc else "",
|
||||
",".join(txn.tags or []),
|
||||
])
|
||||
|
||||
output.seek(0)
|
||||
filename = f"transactions_{datetime.now(timezone.utc).strftime('%Y%m%d')}.csv"
|
||||
await write_audit(db, user_id=user.id, action="data_export")
|
||||
await db.commit()
|
||||
|
||||
return StreamingResponse(
|
||||
iter([output.getvalue()]),
|
||||
media_type="text/csv",
|
||||
headers={"Content-Disposition": f"attachment; filename={filename}"},
|
||||
)
|
||||
54
backend/app/config.py
Normal file
54
backend/app/config.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
from functools import lru_cache
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_file=".env", extra="ignore")
|
||||
|
||||
database_url: str = "postgresql+asyncpg://finance_app:password@postgres:5432/financedb"
|
||||
redis_url: str = "redis://localhost:6379/0"
|
||||
|
||||
encryption_key: str # 32-byte hex string
|
||||
backup_passphrase: str = ""
|
||||
|
||||
environment: str = "production"
|
||||
allow_registration: bool = False
|
||||
base_currency: str = "GBP"
|
||||
|
||||
# JWT — keys read from /run/secrets/ at runtime
|
||||
jwt_private_key_file: str = "/run/secrets/jwt_private.pem"
|
||||
jwt_public_key_file: str = "/run/secrets/jwt_public.pem"
|
||||
jwt_algorithm: str = "RS256"
|
||||
access_token_expire_minutes: int = 15
|
||||
refresh_token_expire_days: int = 7
|
||||
|
||||
# Security
|
||||
csrf_token_expire_hours: int = 24
|
||||
max_login_attempts: int = 5
|
||||
lockout_base_seconds: int = 1800 # 30 min, doubles each time
|
||||
|
||||
# Rate limits (requests per minute)
|
||||
rate_limit_auth: int = 10
|
||||
rate_limit_api: int = 300
|
||||
rate_limit_predictions: int = 20
|
||||
|
||||
# File uploads
|
||||
upload_dir: str = "/app/uploads"
|
||||
max_attachment_bytes: int = 10 * 1024 * 1024 # 10 MB
|
||||
max_attachments_per_txn: int = 10
|
||||
|
||||
# Background jobs
|
||||
price_sync_interval_minutes: int = 15
|
||||
fx_sync_interval_minutes: int = 60
|
||||
snapshot_hour: int = 2 # 2 AM daily
|
||||
backup_hour: int = 3 # 3 AM daily
|
||||
ml_retrain_day: str = "sun" # weekly on Sunday
|
||||
|
||||
@property
|
||||
def is_development(self) -> bool:
|
||||
return self.environment == "development"
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_settings() -> Settings:
|
||||
return Settings()
|
||||
0
backend/app/core/__init__.py
Normal file
0
backend/app/core/__init__.py
Normal file
38
backend/app/core/audit.py
Normal file
38
backend/app/core/audit.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
"""
|
||||
Append-only audit log writer.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
async def write_audit(
|
||||
db: "AsyncSession",
|
||||
*,
|
||||
user_id: uuid.UUID | None,
|
||||
action: str,
|
||||
resource_type: str | None = None,
|
||||
resource_id: uuid.UUID | None = None,
|
||||
ip_address: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
success: bool = True,
|
||||
) -> None:
|
||||
from app.db.models.audit_log import AuditLog
|
||||
log = AuditLog(
|
||||
user_id=user_id,
|
||||
action=action,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
meta=metadata or {},
|
||||
success=success,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db.add(log)
|
||||
# Note: caller is responsible for committing
|
||||
24
backend/app/core/encryption.py
Normal file
24
backend/app/core/encryption.py
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
"""
|
||||
Helpers for re-encrypting all sensitive DB fields during key rotation.
|
||||
"""
|
||||
from app.core.security import decrypt_field, encrypt_field
|
||||
|
||||
|
||||
def reencrypt(data: bytes, old_key_hex: str, new_key_hex: str) -> bytes:
|
||||
"""Re-encrypt a bytea field from old key to new key."""
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
import os
|
||||
|
||||
old_key = bytes.fromhex(old_key_hex)
|
||||
new_key = bytes.fromhex(new_key_hex)
|
||||
|
||||
# Decrypt with old key
|
||||
iv = data[:12]
|
||||
ciphertext_with_tag = data[12:]
|
||||
aesgcm_old = AESGCM(old_key)
|
||||
plaintext = aesgcm_old.decrypt(iv, ciphertext_with_tag, None)
|
||||
|
||||
# Encrypt with new key
|
||||
new_iv = os.urandom(12)
|
||||
aesgcm_new = AESGCM(new_key)
|
||||
return new_iv + aesgcm_new.encrypt(new_iv, plaintext, None)
|
||||
144
backend/app/core/key_rotation.py
Normal file
144
backend/app/core/key_rotation.py
Normal file
|
|
@ -0,0 +1,144 @@
|
|||
"""
|
||||
AES-256-GCM key rotation: decrypt all encrypted fields with OLD key, re-encrypt with NEW key.
|
||||
|
||||
Run while the application is STOPPED:
|
||||
docker compose exec \
|
||||
-e ENCRYPTION_KEY="$OLD_ENCRYPTION_KEY" \
|
||||
-e NEW_ENCRYPTION_KEY="$NEW_ENCRYPTION_KEY" \
|
||||
backend python -m app.core.key_rotation
|
||||
|
||||
On success, update ENCRYPTION_KEY in .env to the new value and restart.
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from typing import Callable
|
||||
|
||||
import psycopg2
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="[rotate] %(message)s")
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _make_cipher(key_hex: str) -> AESGCM:
|
||||
key = bytes.fromhex(key_hex)
|
||||
if len(key) != 32:
|
||||
raise ValueError("Key must be 32 bytes (64 hex chars)")
|
||||
return AESGCM(key)
|
||||
|
||||
|
||||
def _decrypt(cipher: AESGCM, data: bytes) -> bytes:
|
||||
"""Return plaintext bytes given IV(12)||ciphertext+tag."""
|
||||
if not data:
|
||||
return b""
|
||||
return cipher.decrypt(data[:12], data[12:], None)
|
||||
|
||||
|
||||
def _encrypt(cipher: AESGCM, plaintext: bytes) -> bytes:
|
||||
"""Encrypt plaintext bytes → IV(12)||ciphertext+tag."""
|
||||
if not plaintext:
|
||||
return b""
|
||||
iv = os.urandom(12)
|
||||
return iv + cipher.encrypt(iv, plaintext, None)
|
||||
|
||||
|
||||
def _reencrypt(old: AESGCM, new: AESGCM, data: bytes | None) -> bytes | None:
|
||||
if not data:
|
||||
return data
|
||||
plaintext = _decrypt(old, data)
|
||||
return _encrypt(new, plaintext)
|
||||
|
||||
|
||||
def _reencrypt_hex(old: AESGCM, new: AESGCM, hex_str: str | None) -> str | None:
|
||||
"""For fields stored as hex strings (e.g. totp_secret_enc)."""
|
||||
if not hex_str:
|
||||
return hex_str
|
||||
data = bytes.fromhex(hex_str)
|
||||
plaintext = _decrypt(old, data)
|
||||
return _encrypt(new, plaintext).hex()
|
||||
|
||||
|
||||
def rotate(db_url: str, old_key_hex: str, new_key_hex: str) -> None:
|
||||
old = _make_cipher(old_key_hex)
|
||||
new = _make_cipher(new_key_hex)
|
||||
|
||||
conn = psycopg2.connect(db_url)
|
||||
conn.autocommit = False
|
||||
cur = conn.cursor()
|
||||
|
||||
try:
|
||||
# ------------------------------------------------------------------ accounts
|
||||
cur.execute("SELECT id, name, institution, notes FROM accounts WHERE deleted_at IS NULL")
|
||||
rows = cur.fetchall()
|
||||
log.info(f"Rotating {len(rows)} account row(s)…")
|
||||
for row_id, name, institution, notes in rows:
|
||||
cur.execute(
|
||||
"UPDATE accounts SET name=%s, institution=%s, notes=%s WHERE id=%s",
|
||||
(
|
||||
_reencrypt(old, new, bytes(name) if name else None),
|
||||
_reencrypt(old, new, bytes(institution) if institution else None),
|
||||
_reencrypt(old, new, bytes(notes) if notes else None),
|
||||
row_id,
|
||||
),
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------- transactions
|
||||
cur.execute(
|
||||
"SELECT id, description, merchant, notes FROM transactions WHERE deleted_at IS NULL"
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
log.info(f"Rotating {len(rows)} transaction row(s)…")
|
||||
for row_id, description, merchant, notes in rows:
|
||||
cur.execute(
|
||||
"UPDATE transactions SET description=%s, merchant=%s, notes=%s WHERE id=%s",
|
||||
(
|
||||
_reencrypt(old, new, bytes(description) if description else None),
|
||||
_reencrypt(old, new, bytes(merchant) if merchant else None),
|
||||
_reencrypt(old, new, bytes(notes) if notes else None),
|
||||
row_id,
|
||||
),
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------------- users
|
||||
cur.execute("SELECT id, totp_secret FROM users WHERE deleted_at IS NULL")
|
||||
rows = cur.fetchall()
|
||||
log.info(f"Rotating {len(rows)} user row(s)…")
|
||||
for row_id, totp_secret in rows:
|
||||
cur.execute(
|
||||
"UPDATE users SET totp_secret=%s WHERE id=%s",
|
||||
(_reencrypt_hex(old, new, totp_secret), row_id),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
log.info("Key rotation complete — all fields re-encrypted.")
|
||||
log.info("Now update ENCRYPTION_KEY in .env and restart the application.")
|
||||
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
log.exception("Rotation FAILED — rolled back, no data changed.")
|
||||
sys.exit(1)
|
||||
finally:
|
||||
cur.close()
|
||||
conn.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
old_key = os.environ.get("ENCRYPTION_KEY", "")
|
||||
new_key = os.environ.get("NEW_ENCRYPTION_KEY", "")
|
||||
db_url = os.environ.get("DATABASE_URL", "").replace("postgresql+asyncpg://", "postgresql://")
|
||||
|
||||
if not old_key:
|
||||
log.error("ENCRYPTION_KEY (current/old key) is not set")
|
||||
sys.exit(1)
|
||||
if not new_key:
|
||||
log.error("NEW_ENCRYPTION_KEY is not set")
|
||||
sys.exit(1)
|
||||
if not db_url:
|
||||
log.error("DATABASE_URL is not set")
|
||||
sys.exit(1)
|
||||
if old_key == new_key:
|
||||
log.error("NEW_ENCRYPTION_KEY is the same as ENCRYPTION_KEY — nothing to do")
|
||||
sys.exit(1)
|
||||
|
||||
rotate(db_url, old_key, new_key)
|
||||
81
backend/app/core/middleware.py
Normal file
81
backend/app/core/middleware.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
"""
|
||||
Security middleware: headers, CSRF double-submit, request ID, RLS user context.
|
||||
"""
|
||||
import uuid
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
SAFE_METHODS = {"GET", "HEAD", "OPTIONS"}
|
||||
|
||||
SECURITY_HEADERS = {
|
||||
"X-Frame-Options": "DENY",
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"Referrer-Policy": "strict-origin-when-cross-origin",
|
||||
"Permissions-Policy": "camera=(), microphone=(), geolocation=()",
|
||||
"Cross-Origin-Opener-Policy": "same-origin",
|
||||
"Cross-Origin-Resource-Policy": "same-origin",
|
||||
"Strict-Transport-Security": "max-age=63072000; includeSubDomains",
|
||||
"Content-Security-Policy": (
|
||||
"default-src 'self'; "
|
||||
"script-src 'self'; "
|
||||
"style-src 'self' 'unsafe-inline'; "
|
||||
"img-src 'self' data:; "
|
||||
"connect-src 'self'; "
|
||||
"form-action 'self'; "
|
||||
"frame-ancestors 'none'"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
response: Response = await call_next(request)
|
||||
for header, value in SECURITY_HEADERS.items():
|
||||
response.headers[header] = value
|
||||
response.headers["X-Request-ID"] = str(uuid.uuid4())
|
||||
return response
|
||||
|
||||
|
||||
class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
"""Double-submit cookie CSRF protection for mutating requests."""
|
||||
|
||||
EXEMPT_PATHS = {"/api/v1/auth/login", "/api/v1/auth/refresh", "/api/v1/auth/register", "/health"}
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# Always set the csrf_token cookie if it doesn't exist yet
|
||||
existing_csrf = request.cookies.get("csrf_token")
|
||||
|
||||
if request.method in SAFE_METHODS:
|
||||
response: Response = await call_next(request)
|
||||
if not existing_csrf:
|
||||
token = str(uuid.uuid4())
|
||||
response.set_cookie(
|
||||
"csrf_token", token,
|
||||
httponly=False, # must be readable by JS
|
||||
samesite="lax",
|
||||
secure=False, # set True if TLS is terminated at this service
|
||||
)
|
||||
return response
|
||||
|
||||
if request.url.path in self.EXEMPT_PATHS:
|
||||
response = await call_next(request)
|
||||
if not existing_csrf:
|
||||
token = str(uuid.uuid4())
|
||||
response.set_cookie("csrf_token", token, httponly=False, samesite="lax", secure=False)
|
||||
return response
|
||||
|
||||
if request.url.path in {"/api/v1/auth/login", "/api/v1/auth/login/totp"}:
|
||||
return await call_next(request)
|
||||
|
||||
cookie_token = existing_csrf
|
||||
header_token = request.headers.get("X-CSRF-Token")
|
||||
|
||||
if not cookie_token or not header_token or cookie_token != header_token:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "CSRF token missing or invalid"},
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
28
backend/app/core/rate_limiter.py
Normal file
28
backend/app/core/rate_limiter.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
"""
|
||||
Redis sliding window rate limiter.
|
||||
"""
|
||||
import time
|
||||
from redis.asyncio import Redis
|
||||
|
||||
|
||||
async def is_rate_limited(
|
||||
redis: Redis,
|
||||
key: str,
|
||||
limit: int,
|
||||
window_seconds: int = 60,
|
||||
) -> tuple[bool, int]:
|
||||
"""
|
||||
Returns (is_limited, requests_remaining).
|
||||
Uses a sorted set with timestamps as scores for sliding window.
|
||||
"""
|
||||
now = time.time()
|
||||
window_start = now - window_seconds
|
||||
pipe = redis.pipeline()
|
||||
pipe.zremrangebyscore(key, 0, window_start)
|
||||
pipe.zadd(key, {str(now): now})
|
||||
pipe.zcard(key)
|
||||
pipe.expire(key, window_seconds + 1)
|
||||
results = await pipe.execute()
|
||||
count = results[2]
|
||||
remaining = max(0, limit - count)
|
||||
return count > limit, remaining
|
||||
197
backend/app/core/security.py
Normal file
197
backend/app/core/security.py
Normal file
|
|
@ -0,0 +1,197 @@
|
|||
"""
|
||||
Cryptographic primitives: Argon2id password hashing, RS256 JWT, AES-256-GCM field encryption, TOTP.
|
||||
"""
|
||||
import base64
|
||||
import os
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pyotp
|
||||
import qrcode
|
||||
import qrcode.image.svg
|
||||
from argon2 import PasswordHasher
|
||||
from argon2.exceptions import VerifyMismatchError, VerificationError, InvalidHashError
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from app.config import get_settings
|
||||
|
||||
# Argon2id — OWASP recommended parameters
|
||||
_ph = PasswordHasher(
|
||||
time_cost=3,
|
||||
memory_cost=65536,
|
||||
parallelism=4,
|
||||
hash_len=32,
|
||||
salt_len=16,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Password hashing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
return _ph.hash(password)
|
||||
|
||||
|
||||
def verify_password(password: str, hashed: str) -> bool:
|
||||
try:
|
||||
return _ph.verify(hashed, password)
|
||||
except (VerifyMismatchError, VerificationError, InvalidHashError):
|
||||
return False
|
||||
|
||||
|
||||
def password_needs_rehash(hashed: str) -> bool:
|
||||
return _ph.check_needs_rehash(hashed)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JWT (RS256)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _load_private_key() -> str:
|
||||
settings = get_settings()
|
||||
return Path(settings.jwt_private_key_file).read_text()
|
||||
|
||||
|
||||
def _load_public_key() -> str:
|
||||
settings = get_settings()
|
||||
return Path(settings.jwt_public_key_file).read_text()
|
||||
|
||||
|
||||
def create_access_token(subject: str, extra: dict[str, Any] | None = None) -> str:
|
||||
settings = get_settings()
|
||||
now = datetime.now(timezone.utc)
|
||||
payload: dict[str, Any] = {
|
||||
"sub": subject,
|
||||
"iat": now,
|
||||
"exp": now + timedelta(minutes=settings.access_token_expire_minutes),
|
||||
"type": "access",
|
||||
}
|
||||
if extra:
|
||||
payload.update(extra)
|
||||
return jwt.encode(payload, _load_private_key(), algorithm=settings.jwt_algorithm)
|
||||
|
||||
|
||||
def create_refresh_token(subject: str) -> str:
|
||||
settings = get_settings()
|
||||
now = datetime.now(timezone.utc)
|
||||
payload: dict[str, Any] = {
|
||||
"sub": subject,
|
||||
"iat": now,
|
||||
"exp": now + timedelta(days=settings.refresh_token_expire_days),
|
||||
"type": "refresh",
|
||||
"jti": secrets.token_hex(16),
|
||||
}
|
||||
return jwt.encode(payload, _load_private_key(), algorithm=settings.jwt_algorithm)
|
||||
|
||||
|
||||
def decode_token(token: str, token_type: str = "access") -> dict[str, Any]:
|
||||
settings = get_settings()
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
_load_public_key(),
|
||||
algorithms=[settings.jwt_algorithm],
|
||||
options={"verify_exp": True},
|
||||
)
|
||||
if payload.get("type") != token_type:
|
||||
raise JWTError("Invalid token type")
|
||||
return payload
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AES-256-GCM field encryption
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _get_aes_key() -> bytes:
|
||||
"""Derive 32-byte key from hex ENCRYPTION_KEY env var."""
|
||||
settings = get_settings()
|
||||
key_hex = settings.encryption_key
|
||||
key = bytes.fromhex(key_hex)
|
||||
if len(key) != 32:
|
||||
raise ValueError("ENCRYPTION_KEY must be a 32-byte hex string (64 hex chars)")
|
||||
return key
|
||||
|
||||
|
||||
def encrypt_field(plaintext: str) -> bytes:
|
||||
"""Encrypt a string field. Returns IV(12) || ciphertext || tag(16) as bytes."""
|
||||
if not plaintext:
|
||||
return b""
|
||||
key = _get_aes_key()
|
||||
iv = os.urandom(12)
|
||||
aesgcm = AESGCM(key)
|
||||
ciphertext_with_tag = aesgcm.encrypt(iv, plaintext.encode(), None)
|
||||
return iv + ciphertext_with_tag
|
||||
|
||||
|
||||
def decrypt_field(data: bytes) -> str:
|
||||
"""Decrypt bytes produced by encrypt_field."""
|
||||
if not data:
|
||||
return ""
|
||||
key = _get_aes_key()
|
||||
iv = data[:12]
|
||||
ciphertext_with_tag = data[12:]
|
||||
aesgcm = AESGCM(key)
|
||||
return aesgcm.decrypt(iv, ciphertext_with_tag, None).decode()
|
||||
|
||||
|
||||
def encrypt_field_b64(plaintext: str) -> str:
|
||||
"""Convenience: encrypt and return base64 string (for JSON/text contexts)."""
|
||||
return base64.b64encode(encrypt_field(plaintext)).decode()
|
||||
|
||||
|
||||
def decrypt_field_b64(data: str) -> str:
|
||||
return decrypt_field(base64.b64decode(data))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TOTP (RFC 6238)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def generate_totp_secret() -> str:
|
||||
return pyotp.random_base32()
|
||||
|
||||
|
||||
def get_totp_uri(secret: str, email: str) -> str:
|
||||
return pyotp.totp.TOTP(secret).provisioning_uri(
|
||||
name=email, issuer_name="Finance Tracker"
|
||||
)
|
||||
|
||||
|
||||
def generate_totp_qr_png(secret: str, email: str) -> bytes:
|
||||
uri = get_totp_uri(secret, email)
|
||||
img = qrcode.make(uri)
|
||||
from io import BytesIO
|
||||
buf = BytesIO()
|
||||
img.save(buf, format="PNG")
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
def verify_totp(secret: str, code: str) -> bool:
|
||||
totp = pyotp.TOTP(secret)
|
||||
return totp.verify(code, valid_window=1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CSRF token
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def generate_csrf_token() -> str:
|
||||
return secrets.token_hex(32)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Misc helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def generate_backup_codes(count: int = 8) -> list[str]:
|
||||
"""Generate one-time backup codes."""
|
||||
return [secrets.token_hex(4).upper() + "-" + secrets.token_hex(4).upper() for _ in range(count)]
|
||||
|
||||
|
||||
def hash_token(token: str) -> str:
|
||||
"""SHA-256 hash of a bearer token for DB storage."""
|
||||
import hashlib
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
0
backend/app/db/__init__.py
Normal file
0
backend/app/db/__init__.py
Normal file
28
backend/app/db/base.py
Normal file
28
backend/app/db/base.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from app.config import get_settings
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
def create_engine():
|
||||
settings = get_settings()
|
||||
return create_async_engine(
|
||||
settings.database_url,
|
||||
pool_size=10,
|
||||
max_overflow=20,
|
||||
pool_pre_ping=True,
|
||||
echo=settings.is_development,
|
||||
)
|
||||
|
||||
|
||||
def create_session_factory(engine):
|
||||
return async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
19
backend/app/db/models/__init__.py
Normal file
19
backend/app/db/models/__init__.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
from app.db.models.user import User
|
||||
from app.db.models.session import Session
|
||||
from app.db.models.account import Account
|
||||
from app.db.models.category import Category
|
||||
from app.db.models.transaction import Transaction
|
||||
from app.db.models.budget import Budget
|
||||
from app.db.models.asset import Asset
|
||||
from app.db.models.asset_price import AssetPrice
|
||||
from app.db.models.investment_holding import InvestmentHolding
|
||||
from app.db.models.investment_transaction import InvestmentTransaction
|
||||
from app.db.models.currency import Currency, ExchangeRate
|
||||
from app.db.models.net_worth_snapshot import NetWorthSnapshot
|
||||
from app.db.models.audit_log import AuditLog
|
||||
|
||||
__all__ = [
|
||||
"User", "Session", "Account", "Category", "Transaction", "Budget",
|
||||
"Asset", "AssetPrice", "InvestmentHolding", "InvestmentTransaction",
|
||||
"Currency", "ExchangeRate", "NetWorthSnapshot", "AuditLog",
|
||||
]
|
||||
36
backend/app/db/models/account.py
Normal file
36
backend/app/db/models/account.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, LargeBinary, Numeric, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.db.base import Base
|
||||
|
||||
|
||||
class Account(Base):
|
||||
__tablename__ = "accounts"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
name_enc: Mapped[bytes] = mapped_column("name", LargeBinary, nullable=False)
|
||||
institution_enc: Mapped[bytes | None] = mapped_column("institution", LargeBinary, nullable=True)
|
||||
type: Mapped[str] = mapped_column(String(30), nullable=False)
|
||||
currency: Mapped[str] = mapped_column(String(10), nullable=False)
|
||||
current_balance: Mapped[Decimal] = mapped_column(Numeric(20, 8), default=0, nullable=False)
|
||||
credit_limit: Mapped[Decimal | None] = mapped_column(Numeric(20, 8), nullable=True)
|
||||
interest_rate: Mapped[Decimal | None] = mapped_column(Numeric(8, 4), nullable=True)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||
include_in_net_worth: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||
color: Mapped[str] = mapped_column(String(7), default="#6366f1", nullable=False)
|
||||
icon: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
notes_enc: Mapped[bytes | None] = mapped_column("notes", LargeBinary, nullable=True)
|
||||
meta: Mapped[dict] = mapped_column(JSONB, default=dict, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
deleted_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
user: Mapped["User"] = relationship(back_populates="accounts", lazy="noload") # type: ignore[name-defined]
|
||||
transactions: Mapped[list["Transaction"]] = relationship(foreign_keys="Transaction.account_id", back_populates="account", lazy="noload") # type: ignore[name-defined]
|
||||
holdings: Mapped[list["InvestmentHolding"]] = relationship(back_populates="account", lazy="noload") # type: ignore[name-defined]
|
||||
32
backend/app/db/models/asset.py
Normal file
32
backend/app/db/models/asset.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, Numeric, String, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.db.base import Base
|
||||
|
||||
|
||||
class Asset(Base):
|
||||
__tablename__ = "assets"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
symbol: Mapped[str] = mapped_column(Text, nullable=False, index=True)
|
||||
name: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
type: Mapped[str] = mapped_column(String(30), nullable=False) # stock|etf|mutual_fund|bond|crypto|commodity|other
|
||||
currency: Mapped[str] = mapped_column(String(10), nullable=False)
|
||||
exchange: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
isin: Mapped[str | None] = mapped_column(String(12), nullable=True)
|
||||
data_source: Mapped[str] = mapped_column(String(30), default="yahoo_finance", nullable=False)
|
||||
data_source_id: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
last_price: Mapped[Decimal | None] = mapped_column(Numeric(20, 8), nullable=True)
|
||||
last_price_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
price_change_24h: Mapped[Decimal | None] = mapped_column(Numeric(10, 4), nullable=True)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
prices: Mapped[list["AssetPrice"]] = relationship(back_populates="asset", lazy="noload") # type: ignore[name-defined]
|
||||
holdings: Mapped[list["InvestmentHolding"]] = relationship(back_populates="asset", lazy="noload") # type: ignore[name-defined]
|
||||
25
backend/app/db/models/asset_price.py
Normal file
25
backend/app/db/models/asset_price.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
import uuid
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlalchemy import Date, DateTime, ForeignKey, Numeric
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.db.base import Base
|
||||
|
||||
|
||||
class AssetPrice(Base):
|
||||
__tablename__ = "asset_prices"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
asset_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
date: Mapped[date] = mapped_column(Date, nullable=False, index=True)
|
||||
open: Mapped[Decimal | None] = mapped_column(Numeric(20, 8), nullable=True)
|
||||
high: Mapped[Decimal | None] = mapped_column(Numeric(20, 8), nullable=True)
|
||||
low: Mapped[Decimal | None] = mapped_column(Numeric(20, 8), nullable=True)
|
||||
close: Mapped[Decimal] = mapped_column(Numeric(20, 8), nullable=False)
|
||||
volume: Mapped[Decimal | None] = mapped_column(Numeric(30, 8), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
asset: Mapped["Asset"] = relationship(back_populates="prices", lazy="noload") # type: ignore[name-defined]
|
||||
23
backend/app/db/models/audit_log.py
Normal file
23
backend/app/db/models/audit_log.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, String, Text
|
||||
from sqlalchemy.dialects.postgresql import INET, JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.db.base import Base
|
||||
|
||||
|
||||
class AuditLog(Base):
|
||||
__tablename__ = "audit_logs"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True, index=True)
|
||||
action: Mapped[str] = mapped_column(String(50), nullable=False, index=True)
|
||||
resource_type: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
resource_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), nullable=True)
|
||||
ip_address: Mapped[str | None] = mapped_column(INET, nullable=True)
|
||||
user_agent: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
meta: Mapped[dict] = mapped_column("metadata", JSONB, default=dict, nullable=False)
|
||||
success: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
30
backend/app/db/models/budget.py
Normal file
30
backend/app/db/models/budget.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
import uuid
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlalchemy import Boolean, Date, DateTime, ForeignKey, Numeric, String, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.db.base import Base
|
||||
|
||||
|
||||
class Budget(Base):
|
||||
__tablename__ = "budgets"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
category_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("categories.id"), nullable=False)
|
||||
name: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
amount: Mapped[Decimal] = mapped_column(Numeric(20, 8), nullable=False)
|
||||
currency: Mapped[str] = mapped_column(String(10), nullable=False)
|
||||
period: Mapped[str] = mapped_column(String(20), nullable=False) # weekly|monthly|quarterly|yearly
|
||||
start_date: Mapped[date] = mapped_column(Date, nullable=False)
|
||||
end_date: Mapped[date | None] = mapped_column(Date, nullable=True)
|
||||
rollover: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
alert_threshold: Mapped[Decimal] = mapped_column(Numeric(5, 2), default=80, nullable=False)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False, index=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
category: Mapped["Category"] = relationship(back_populates="budgets", lazy="noload") # type: ignore[name-defined]
|
||||
26
backend/app/db/models/category.py
Normal file
26
backend/app/db/models/category.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.db.base import Base
|
||||
|
||||
|
||||
class Category(Base):
|
||||
__tablename__ = "categories"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=True, index=True)
|
||||
name: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
parent_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), ForeignKey("categories.id"), nullable=True)
|
||||
type: Mapped[str] = mapped_column(String(20), nullable=False) # income | expense | transfer
|
||||
icon: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
color: Mapped[str | None] = mapped_column(String(7), nullable=True)
|
||||
is_system: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
sort_order: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
children: Mapped[list["Category"]] = relationship(lazy="noload")
|
||||
budgets: Mapped[list["Budget"]] = relationship(back_populates="category", lazy="noload") # type: ignore[name-defined]
|
||||
31
backend/app/db/models/currency.py
Normal file
31
backend/app/db/models/currency.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, Integer, Numeric, String, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.db.base import Base
|
||||
|
||||
|
||||
class Currency(Base):
|
||||
__tablename__ = "currencies"
|
||||
|
||||
code: Mapped[str] = mapped_column(String(10), primary_key=True)
|
||||
name: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
symbol: Mapped[str] = mapped_column(String(5), nullable=False)
|
||||
is_crypto: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
decimal_places: Mapped[int] = mapped_column(Integer, default=2, nullable=False)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||
|
||||
|
||||
class ExchangeRate(Base):
|
||||
__tablename__ = "exchange_rates"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
base_currency: Mapped[str] = mapped_column(String(10), nullable=False, index=True)
|
||||
quote_currency: Mapped[str] = mapped_column(String(10), nullable=False, index=True)
|
||||
rate: Mapped[Decimal] = mapped_column(Numeric(20, 10), nullable=False)
|
||||
source: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
fetched_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True)
|
||||
27
backend/app/db/models/investment_holding.py
Normal file
27
backend/app/db/models/investment_holding.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, Numeric, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.db.base import Base
|
||||
|
||||
|
||||
class InvestmentHolding(Base):
|
||||
__tablename__ = "investment_holdings"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("accounts.id"), nullable=False)
|
||||
asset_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("assets.id"), nullable=False)
|
||||
quantity: Mapped[Decimal] = mapped_column(Numeric(30, 10), default=0, nullable=False)
|
||||
avg_cost_basis: Mapped[Decimal] = mapped_column(Numeric(20, 8), default=0, nullable=False)
|
||||
currency: Mapped[str] = mapped_column(String(10), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
account: Mapped["Account"] = relationship(back_populates="holdings", lazy="noload") # type: ignore[name-defined]
|
||||
asset: Mapped["Asset"] = relationship(back_populates="holdings", lazy="noload") # type: ignore[name-defined]
|
||||
investment_transactions: Mapped[list["InvestmentTransaction"]] = relationship(back_populates="holding", lazy="noload") # type: ignore[name-defined]
|
||||
29
backend/app/db/models/investment_transaction.py
Normal file
29
backend/app/db/models/investment_transaction.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
import uuid
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlalchemy import Date, DateTime, ForeignKey, LargeBinary, Numeric, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.db.base import Base
|
||||
|
||||
|
||||
class InvestmentTransaction(Base):
|
||||
__tablename__ = "investment_transactions"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
holding_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("investment_holdings.id"), nullable=False, index=True)
|
||||
transaction_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), ForeignKey("transactions.id"), nullable=True)
|
||||
type: Mapped[str] = mapped_column(String(20), nullable=False) # buy|sell|dividend|split|merger|transfer_in|transfer_out|fee
|
||||
quantity: Mapped[Decimal] = mapped_column(Numeric(30, 10), nullable=False)
|
||||
price: Mapped[Decimal] = mapped_column(Numeric(20, 8), nullable=False)
|
||||
fees: Mapped[Decimal] = mapped_column(Numeric(20, 8), default=0, nullable=False)
|
||||
total_amount: Mapped[Decimal] = mapped_column(Numeric(20, 8), nullable=False)
|
||||
currency: Mapped[str] = mapped_column(String(10), nullable=False)
|
||||
date: Mapped[date] = mapped_column(Date, nullable=False, index=True)
|
||||
notes_enc: Mapped[bytes | None] = mapped_column("notes", LargeBinary, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
holding: Mapped["InvestmentHolding"] = relationship(back_populates="investment_transactions", lazy="noload") # type: ignore[name-defined]
|
||||
23
backend/app/db/models/net_worth_snapshot.py
Normal file
23
backend/app/db/models/net_worth_snapshot.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
import uuid
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlalchemy import Date, DateTime, ForeignKey, Numeric, String
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.db.base import Base
|
||||
|
||||
|
||||
class NetWorthSnapshot(Base):
|
||||
__tablename__ = "net_worth_snapshots"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
||||
date: Mapped[date] = mapped_column(Date, nullable=False)
|
||||
total_assets: Mapped[Decimal] = mapped_column(Numeric(20, 8), nullable=False)
|
||||
total_liabilities: Mapped[Decimal] = mapped_column(Numeric(20, 8), nullable=False)
|
||||
net_worth: Mapped[Decimal] = mapped_column(Numeric(20, 8), nullable=False)
|
||||
base_currency: Mapped[str] = mapped_column(String(10), nullable=False)
|
||||
breakdown: Mapped[dict] = mapped_column(JSONB, default=dict, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
24
backend/app/db/models/session.py
Normal file
24
backend/app/db/models/session.py
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, Text
|
||||
from sqlalchemy.dialects.postgresql import INET, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.db.base import Base
|
||||
|
||||
|
||||
class Session(Base):
|
||||
__tablename__ = "sessions"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
token_hash: Mapped[str] = mapped_column(Text, unique=True, nullable=False, index=True)
|
||||
ip_address: Mapped[str | None] = mapped_column(INET, nullable=True)
|
||||
user_agent: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
last_active_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
revoked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
user: Mapped["User"] = relationship(back_populates="sessions", lazy="noload") # type: ignore[name-defined]
|
||||
42
backend/app/db/models/transaction.py
Normal file
42
backend/app/db/models/transaction.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
import uuid
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlalchemy import Boolean, Date, DateTime, ForeignKey, LargeBinary, Numeric, String, Text
|
||||
from sqlalchemy.dialects.postgresql import ARRAY, JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.db.base import Base
|
||||
|
||||
|
||||
class Transaction(Base):
|
||||
__tablename__ = "transactions"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("accounts.id"), nullable=False, index=True)
|
||||
transfer_account_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), ForeignKey("accounts.id"), nullable=True)
|
||||
category_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), ForeignKey("categories.id"), nullable=True, index=True)
|
||||
type: Mapped[str] = mapped_column(String(20), nullable=False) # income|expense|transfer|investment
|
||||
status: Mapped[str] = mapped_column(String(20), default="cleared", nullable=False)
|
||||
amount: Mapped[Decimal] = mapped_column(Numeric(20, 8), nullable=False)
|
||||
amount_base: Mapped[Decimal | None] = mapped_column(Numeric(20, 8), nullable=True)
|
||||
currency: Mapped[str] = mapped_column(String(10), nullable=False)
|
||||
base_currency: Mapped[str] = mapped_column(String(10), nullable=False)
|
||||
exchange_rate: Mapped[Decimal | None] = mapped_column(Numeric(20, 10), nullable=True)
|
||||
date: Mapped[date] = mapped_column(Date, nullable=False, index=True)
|
||||
description_enc: Mapped[bytes] = mapped_column("description", LargeBinary, nullable=False)
|
||||
merchant_enc: Mapped[bytes | None] = mapped_column("merchant", LargeBinary, nullable=True)
|
||||
notes_enc: Mapped[bytes | None] = mapped_column("notes", LargeBinary, nullable=True)
|
||||
tags: Mapped[list[str]] = mapped_column(ARRAY(Text), default=list, nullable=False)
|
||||
is_recurring: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
recurring_rule: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
||||
attachment_refs: Mapped[list] = mapped_column(JSONB, default=list, nullable=False)
|
||||
import_hash: Mapped[str | None] = mapped_column(Text, nullable=True, index=True)
|
||||
meta: Mapped[dict] = mapped_column(JSONB, default=dict, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
deleted_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
account: Mapped["Account"] = relationship(foreign_keys=[account_id], back_populates="transactions", lazy="noload") # type: ignore[name-defined]
|
||||
category: Mapped["Category | None"] = relationship(lazy="noload") # type: ignore[name-defined]
|
||||
33
backend/app/db/models/user.py
Normal file
33
backend/app/db/models/user.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, Integer, String, Text
|
||||
from sqlalchemy.dialects.postgresql import INET, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.db.base import Base
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
email: Mapped[str] = mapped_column(Text, unique=True, nullable=False, index=True)
|
||||
password_hash: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
totp_secret_enc: Mapped[bytes | None] = mapped_column("totp_secret", type_=String, nullable=True)
|
||||
totp_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
totp_backup_codes_enc: Mapped[str | None] = mapped_column("totp_backup_codes", Text, nullable=True)
|
||||
display_name: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
base_currency: Mapped[str] = mapped_column(String(10), default="GBP", nullable=False)
|
||||
theme: Mapped[str] = mapped_column(String(20), default="dark", nullable=False)
|
||||
locale: Mapped[str] = mapped_column(String(20), default="en-GB", nullable=False)
|
||||
failed_login_attempts: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
|
||||
locked_until: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
last_login_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
last_login_ip: Mapped[str | None] = mapped_column(INET, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
deleted_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
accounts: Mapped[list["Account"]] = relationship(back_populates="user", lazy="noload") # type: ignore[name-defined]
|
||||
sessions: Mapped[list["Session"]] = relationship(back_populates="user", lazy="noload") # type: ignore[name-defined]
|
||||
92
backend/app/dependencies.py
Normal file
92
backend/app/dependencies.py
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
"""
|
||||
FastAPI dependency injection: DB session, Redis, current authenticated user.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import Cookie, Depends, HTTPException, Request, status
|
||||
from jose import JWTError
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.security import decode_token, hash_token
|
||||
from app.db.models.session import Session
|
||||
from app.db.models.user import User
|
||||
|
||||
# These are set by main.py lifespan
|
||||
_session_factory = None
|
||||
_redis_client: Redis | None = None
|
||||
|
||||
|
||||
def set_session_factory(factory):
|
||||
global _session_factory
|
||||
_session_factory = factory
|
||||
|
||||
|
||||
def get_session_factory():
|
||||
return _session_factory
|
||||
|
||||
|
||||
def set_redis_client(client: Redis):
|
||||
global _redis_client
|
||||
_redis_client = client
|
||||
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with _session_factory() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def get_redis() -> Redis:
|
||||
return _redis_client
|
||||
|
||||
|
||||
def _extract_bearer(request: Request) -> str | None:
|
||||
auth = request.headers.get("Authorization", "")
|
||||
if auth.startswith("Bearer "):
|
||||
return auth[7:]
|
||||
return None
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> User:
|
||||
token = _extract_bearer(request)
|
||||
if not token:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated")
|
||||
|
||||
try:
|
||||
payload = decode_token(token, token_type="access")
|
||||
user_id = uuid.UUID(payload["sub"])
|
||||
except (JWTError, ValueError, KeyError):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
|
||||
|
||||
token_hash = hash_token(token)
|
||||
from datetime import datetime, timezone
|
||||
result = await db.execute(
|
||||
select(Session).where(
|
||||
Session.user_id == user_id,
|
||||
Session.token_hash == token_hash,
|
||||
Session.revoked_at.is_(None),
|
||||
Session.expires_at > datetime.now(timezone.utc),
|
||||
)
|
||||
)
|
||||
session = result.scalar_one_or_none()
|
||||
if not session:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Session expired or revoked")
|
||||
|
||||
result = await db.execute(
|
||||
select(User).where(User.id == user_id, User.deleted_at.is_(None))
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
|
||||
|
||||
# Set RLS context so PostgreSQL RLS policies apply
|
||||
await db.execute(text(f"SET LOCAL app.current_user_id = '{user_id}'"))
|
||||
|
||||
return user
|
||||
88
backend/app/main.py
Normal file
88
backend/app/main.py
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
"""
|
||||
FastAPI application factory with lifespan management.
|
||||
"""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import structlog
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from redis.asyncio import Redis, from_url
|
||||
|
||||
from app.config import get_settings
|
||||
from app.core.middleware import CSRFMiddleware, SecurityHeadersMiddleware
|
||||
from app.db.base import create_engine, create_session_factory
|
||||
from app.dependencies import set_redis_client, set_session_factory
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
settings = get_settings()
|
||||
|
||||
# Database
|
||||
engine = create_engine()
|
||||
session_factory = create_session_factory(engine)
|
||||
set_session_factory(session_factory)
|
||||
|
||||
# Redis
|
||||
redis: Redis = from_url(settings.redis_url, decode_responses=False)
|
||||
set_redis_client(redis)
|
||||
|
||||
# Seed system categories if needed
|
||||
from app.services.category_service import seed_system_categories
|
||||
async with session_factory() as db:
|
||||
await seed_system_categories(db)
|
||||
await db.commit()
|
||||
|
||||
# Background scheduler
|
||||
from app.workers.scheduler import start_scheduler, stop_scheduler
|
||||
await start_scheduler()
|
||||
|
||||
logger.info("startup_complete", env=settings.environment)
|
||||
yield
|
||||
|
||||
await stop_scheduler()
|
||||
await redis.aclose()
|
||||
await engine.dispose()
|
||||
logger.info("shutdown_complete")
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
settings = get_settings()
|
||||
|
||||
app = FastAPI(
|
||||
title="Finance Tracker",
|
||||
version="0.1.0",
|
||||
docs_url="/docs" if settings.is_development else None,
|
||||
redoc_url="/redoc" if settings.is_development else None,
|
||||
openapi_url="/openapi.json" if settings.is_development else None,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# CORS — only allow same origin in production
|
||||
origins = ["http://localhost:5173"] if settings.is_development else []
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
# Health check (no auth required)
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
# API routers
|
||||
from app.api.router import router
|
||||
app.include_router(router, prefix="/api/v1")
|
||||
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
0
backend/app/ml/__init__.py
Normal file
0
backend/app/ml/__init__.py
Normal file
119
backend/app/ml/feature_engineering.py
Normal file
119
backend/app/ml/feature_engineering.py
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import pandas as pd
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
async def get_monthly_category_spending(db: AsyncSession, user_id: str) -> pd.DataFrame:
|
||||
result = await db.execute(text("""
|
||||
SELECT
|
||||
COALESCE(t.category_id::text, 'uncategorised') AS category_id,
|
||||
COALESCE(c.name, 'Uncategorised') AS category_name,
|
||||
DATE_TRUNC('month', t.date)::date AS ds,
|
||||
SUM(ABS(t.amount))::float AS y
|
||||
FROM transactions t
|
||||
LEFT JOIN categories c ON c.id = t.category_id
|
||||
WHERE t.user_id = CAST(:uid AS uuid)
|
||||
AND t.type = 'expense'
|
||||
AND t.deleted_at IS NULL
|
||||
AND t.status != 'void'
|
||||
GROUP BY t.category_id, c.name, DATE_TRUNC('month', t.date)
|
||||
ORDER BY ds ASC
|
||||
"""), {"uid": str(user_id)})
|
||||
rows = result.fetchall()
|
||||
if not rows:
|
||||
return pd.DataFrame(columns=["category_id", "category_name", "ds", "y"])
|
||||
df = pd.DataFrame(rows, columns=["category_id", "category_name", "ds", "y"])
|
||||
df["ds"] = pd.to_datetime(df["ds"])
|
||||
df["y"] = df["y"].astype(float)
|
||||
return df
|
||||
|
||||
|
||||
async def get_monthly_net_worth(db: AsyncSession, user_id: str) -> pd.DataFrame:
|
||||
result = await db.execute(text("""
|
||||
SELECT date::text AS ds, net_worth::float AS y
|
||||
FROM net_worth_snapshots
|
||||
WHERE user_id = CAST(:uid AS uuid)
|
||||
ORDER BY date ASC
|
||||
"""), {"uid": str(user_id)})
|
||||
rows = result.fetchall()
|
||||
if not rows:
|
||||
return pd.DataFrame(columns=["ds", "y"])
|
||||
df = pd.DataFrame(rows, columns=["ds", "y"])
|
||||
df["ds"] = pd.to_datetime(df["ds"])
|
||||
df["y"] = df["y"].astype(float)
|
||||
# Resample to monthly end, keeping last value
|
||||
df = df.set_index("ds").resample("ME").last().dropna().reset_index()
|
||||
df.columns = ["ds", "y"]
|
||||
return df
|
||||
|
||||
|
||||
async def get_current_month_spending(db: AsyncSession, user_id: str) -> pd.DataFrame:
|
||||
result = await db.execute(text("""
|
||||
SELECT
|
||||
COALESCE(t.category_id::text, 'uncategorised') AS category_id,
|
||||
COALESCE(c.name, 'Uncategorised') AS category_name,
|
||||
SUM(ABS(t.amount))::float AS spent
|
||||
FROM transactions t
|
||||
LEFT JOIN categories c ON c.id = t.category_id
|
||||
WHERE t.user_id = CAST(:uid AS uuid)
|
||||
AND t.type = 'expense'
|
||||
AND t.deleted_at IS NULL
|
||||
AND t.status != 'void'
|
||||
AND DATE_TRUNC('month', t.date) = DATE_TRUNC('month', CURRENT_DATE)
|
||||
GROUP BY t.category_id, c.name
|
||||
"""), {"uid": str(user_id)})
|
||||
rows = result.fetchall()
|
||||
if not rows:
|
||||
return pd.DataFrame(columns=["category_id", "category_name", "spent"])
|
||||
df = pd.DataFrame(rows, columns=["category_id", "category_name", "spent"])
|
||||
df["spent"] = df["spent"].astype(float)
|
||||
return df
|
||||
|
||||
|
||||
async def get_portfolio_monthly_returns(db: AsyncSession, user_id: str) -> pd.DataFrame:
|
||||
"""Monthly close prices for each asset in user's portfolio."""
|
||||
result = await db.execute(text("""
|
||||
SELECT
|
||||
a.symbol,
|
||||
DATE_TRUNC('month', ap.date)::date AS month,
|
||||
(ARRAY_AGG(ap.close ORDER BY ap.date DESC))[1]::float AS close
|
||||
FROM investment_holdings h
|
||||
JOIN assets a ON a.id = h.asset_id
|
||||
JOIN asset_prices ap ON ap.asset_id = h.asset_id
|
||||
WHERE h.user_id = CAST(:uid AS uuid)
|
||||
AND h.deleted_at IS NULL
|
||||
GROUP BY a.symbol, DATE_TRUNC('month', ap.date)
|
||||
ORDER BY a.symbol, month ASC
|
||||
"""), {"uid": str(user_id)})
|
||||
rows = result.fetchall()
|
||||
if not rows:
|
||||
return pd.DataFrame(columns=["symbol", "month", "close"])
|
||||
df = pd.DataFrame(rows, columns=["symbol", "month", "close"])
|
||||
df["month"] = pd.to_datetime(df["month"])
|
||||
df["close"] = df["close"].astype(float)
|
||||
return df
|
||||
|
||||
|
||||
async def get_daily_cash_flow(db: AsyncSession, user_id: str, days: int = 90) -> pd.DataFrame:
|
||||
result = await db.execute(text("""
|
||||
SELECT
|
||||
t.date::date AS ds,
|
||||
SUM(CASE WHEN t.amount > 0 THEN t.amount ELSE 0 END)::float AS inflow,
|
||||
SUM(CASE WHEN t.amount < 0 THEN ABS(t.amount) ELSE 0 END)::float AS outflow
|
||||
FROM transactions t
|
||||
WHERE t.user_id = CAST(:uid AS uuid)
|
||||
AND t.deleted_at IS NULL
|
||||
AND t.status != 'void'
|
||||
AND t.type IN ('income', 'expense')
|
||||
AND t.date >= CURRENT_DATE - :days
|
||||
GROUP BY t.date
|
||||
ORDER BY t.date ASC
|
||||
"""), {"uid": str(user_id), "days": days})
|
||||
rows = result.fetchall()
|
||||
if not rows:
|
||||
return pd.DataFrame(columns=["ds", "inflow", "outflow"])
|
||||
df = pd.DataFrame(rows, columns=["ds", "inflow", "outflow"])
|
||||
df["ds"] = pd.to_datetime(df["ds"])
|
||||
return df
|
||||
135
backend/app/ml/monte_carlo.py
Normal file
135
backend/app/ml/monte_carlo.py
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
from dateutil.relativedelta import relativedelta
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
DEFAULT_MU = 0.07 / 12 # 7% annual expected return, monthly
|
||||
DEFAULT_SIGMA = 0.15 / (12 ** 0.5) # 15% annual vol, monthly
|
||||
DT = 1.0 / 12
|
||||
|
||||
|
||||
def _project_months(from_date: date, n: int) -> list[str]:
|
||||
d = from_date.replace(day=1)
|
||||
return [(d + relativedelta(months=i + 1)).strftime("%Y-%m") for i in range(n)]
|
||||
|
||||
|
||||
def run_monte_carlo(
|
||||
prices_df: pd.DataFrame,
|
||||
holdings: list[dict],
|
||||
years: int = 5,
|
||||
n_sims: int = 1000,
|
||||
annual_contribution: float = 0.0,
|
||||
) -> dict:
|
||||
"""
|
||||
prices_df: columns [symbol, month, close]
|
||||
holdings: [{"symbol": str, "quantity": float, "current_value": float}]
|
||||
Returns percentile paths and summary stats.
|
||||
"""
|
||||
n_months = years * 12
|
||||
today = date.today()
|
||||
future_dates = _project_months(today, n_months)
|
||||
monthly_contribution = annual_contribution / 12.0
|
||||
|
||||
symbols = [h["symbol"] for h in holdings]
|
||||
current_values = np.array([float(h.get("current_value") or 0) for h in holdings])
|
||||
total_value = float(current_values.sum())
|
||||
|
||||
if total_value <= 0:
|
||||
return {
|
||||
"dates": future_dates,
|
||||
"percentiles": {},
|
||||
"current_value": 0.0,
|
||||
"expected_value": 0.0,
|
||||
"probability_of_gain": 0.5,
|
||||
"insufficient_data": True,
|
||||
}
|
||||
|
||||
# Compute per-asset parameters from price history
|
||||
n_assets = len(symbols)
|
||||
mus = np.full(n_assets, DEFAULT_MU)
|
||||
sigmas = np.full(n_assets, DEFAULT_SIGMA)
|
||||
corr = np.eye(n_assets)
|
||||
|
||||
if not prices_df.empty:
|
||||
for i, sym in enumerate(symbols):
|
||||
sym_prices = prices_df[prices_df["symbol"] == sym].sort_values("month")
|
||||
if len(sym_prices) >= 3:
|
||||
closes = sym_prices["close"].values.astype(float)
|
||||
log_rets = np.diff(np.log(closes[closes > 0]))
|
||||
if len(log_rets) >= 2:
|
||||
mus[i] = float(np.mean(log_rets))
|
||||
sigmas[i] = float(np.std(log_rets))
|
||||
|
||||
# Build correlation matrix from overlapping return series
|
||||
if n_assets > 1:
|
||||
ret_series = {}
|
||||
for sym in symbols:
|
||||
sym_prices = prices_df[prices_df["symbol"] == sym].sort_values("month")
|
||||
if len(sym_prices) >= 3:
|
||||
closes = sym_prices["close"].values.astype(float)
|
||||
log_rets = np.diff(np.log(closes[closes > 0]))
|
||||
ret_series[sym] = log_rets
|
||||
|
||||
if len(ret_series) == n_assets:
|
||||
min_len = min(len(v) for v in ret_series.values())
|
||||
if min_len >= 3:
|
||||
matrix = np.array([v[-min_len:] for v in ret_series.values()])
|
||||
corr = np.corrcoef(matrix)
|
||||
corr = np.clip(corr, -0.99, 0.99)
|
||||
np.fill_diagonal(corr, 1.0)
|
||||
|
||||
# Covariance matrix and Cholesky decomposition
|
||||
cov = np.outer(sigmas, sigmas) * corr
|
||||
try:
|
||||
L = np.linalg.cholesky(cov)
|
||||
except np.linalg.LinAlgError:
|
||||
# Fall back to diagonal covariance
|
||||
L = np.diag(sigmas)
|
||||
|
||||
# Portfolio weights
|
||||
weights = current_values / total_value
|
||||
|
||||
# GBM simulation
|
||||
rng = np.random.default_rng(42)
|
||||
portfolio_paths = np.zeros((n_sims, n_months))
|
||||
|
||||
for sim in range(n_sims):
|
||||
asset_values = current_values.copy()
|
||||
for t in range(n_months):
|
||||
Z = rng.standard_normal(n_assets)
|
||||
corr_Z = L @ Z
|
||||
# GBM step for each asset
|
||||
asset_values = asset_values * np.exp(
|
||||
(mus - 0.5 * sigmas ** 2) * DT + sigmas * np.sqrt(DT) * corr_Z
|
||||
)
|
||||
port_val = float(asset_values.sum()) + monthly_contribution * (t + 1)
|
||||
portfolio_paths[sim, t] = max(0.0, port_val)
|
||||
|
||||
# Compute percentile paths
|
||||
pcts = {
|
||||
"p10": np.percentile(portfolio_paths, 10, axis=0),
|
||||
"p25": np.percentile(portfolio_paths, 25, axis=0),
|
||||
"p50": np.percentile(portfolio_paths, 50, axis=0),
|
||||
"p75": np.percentile(portfolio_paths, 75, axis=0),
|
||||
"p90": np.percentile(portfolio_paths, 90, axis=0),
|
||||
}
|
||||
|
||||
final_values = portfolio_paths[:, -1]
|
||||
prob_gain = float(np.mean(final_values > total_value))
|
||||
expected_value = float(np.median(final_values))
|
||||
|
||||
return {
|
||||
"dates": future_dates,
|
||||
"percentiles": {
|
||||
k: [{"date": d, "value": round(float(v), 2)} for d, v in zip(future_dates, arr)]
|
||||
for k, arr in pcts.items()
|
||||
},
|
||||
"current_value": round(total_value, 2),
|
||||
"expected_value": round(expected_value, 2),
|
||||
"probability_of_gain": round(prob_gain, 3),
|
||||
"insufficient_data": False,
|
||||
}
|
||||
102
backend/app/ml/net_worth_projection.py
Normal file
102
backend/app/ml/net_worth_projection.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from datetime import date
|
||||
from dateutil.relativedelta import relativedelta
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
def _project_months(from_date: date, n: int) -> list[str]:
|
||||
months = []
|
||||
d = from_date.replace(day=1)
|
||||
for i in range(1, n + 1):
|
||||
months.append((d + relativedelta(months=i)).strftime("%Y-%m"))
|
||||
return months
|
||||
|
||||
|
||||
def project_net_worth(df: pd.DataFrame, years: int = 5) -> dict:
|
||||
"""
|
||||
df columns: ds (monthly datetime), y (net_worth float)
|
||||
Returns history + 3-scenario projections.
|
||||
"""
|
||||
n_months = years * 12
|
||||
today = date.today()
|
||||
future_dates = _project_months(today, n_months)
|
||||
|
||||
history = [
|
||||
{"date": row["ds"].strftime("%Y-%m"), "value": round(float(row["y"]), 2)}
|
||||
for _, row in df.iterrows()
|
||||
]
|
||||
|
||||
if df.empty or len(df) < 2:
|
||||
# No data — return flat projection from 0
|
||||
last_val = float(df["y"].iloc[-1]) if not df.empty else 0.0
|
||||
flat = [{"date": d, "value": round(last_val, 2)} for d in future_dates]
|
||||
return {
|
||||
"history": history,
|
||||
"projections": {"conservative": flat, "base": flat, "optimistic": flat},
|
||||
"insufficient_data": True,
|
||||
}
|
||||
|
||||
try:
|
||||
from statsmodels.tsa.holtwinters import ExponentialSmoothing
|
||||
|
||||
values = df["y"].tolist()
|
||||
|
||||
if len(values) >= 12:
|
||||
model = ExponentialSmoothing(values, trend="add", seasonal="add", seasonal_periods=12)
|
||||
elif len(values) >= 4:
|
||||
model = ExponentialSmoothing(values, trend="add", seasonal=None)
|
||||
else:
|
||||
model = ExponentialSmoothing(values, trend="add", seasonal=None)
|
||||
|
||||
fit = model.fit(optimized=True, disp=False)
|
||||
base_fcast = fit.forecast(n_months)
|
||||
|
||||
# Estimate monthly trend from the fit
|
||||
monthly_trend = float(np.mean(np.diff(base_fcast[:12]))) if len(base_fcast) >= 12 else 0.0
|
||||
last_val = float(values[-1])
|
||||
|
||||
# Scale trends for scenarios
|
||||
def build_scenario(scale: float) -> list[dict]:
|
||||
pts = []
|
||||
v = last_val
|
||||
for i, d in enumerate(future_dates):
|
||||
v = float(base_fcast[i]) + (scale - 1.0) * monthly_trend * (i + 1)
|
||||
pts.append({"date": d, "value": round(v, 2)})
|
||||
return pts
|
||||
|
||||
return {
|
||||
"history": history,
|
||||
"projections": {
|
||||
"conservative": build_scenario(0.5),
|
||||
"base": [{"date": d, "value": round(float(v), 2)} for d, v in zip(future_dates, base_fcast)],
|
||||
"optimistic": build_scenario(1.5),
|
||||
},
|
||||
"insufficient_data": False,
|
||||
}
|
||||
|
||||
except Exception:
|
||||
# Fallback: linear trend from last 2 values
|
||||
trend = float(df["y"].iloc[-1]) - float(df["y"].iloc[-2])
|
||||
last_val = float(df["y"].iloc[-1])
|
||||
|
||||
def linear_scenario(t_scale: float) -> list[dict]:
|
||||
return [
|
||||
{"date": d, "value": round(last_val + t_scale * trend * (i + 1), 2)}
|
||||
for i, d in enumerate(future_dates)
|
||||
]
|
||||
|
||||
return {
|
||||
"history": history,
|
||||
"projections": {
|
||||
"conservative": linear_scenario(0.5),
|
||||
"base": linear_scenario(1.0),
|
||||
"optimistic": linear_scenario(1.5),
|
||||
},
|
||||
"insufficient_data": False,
|
||||
}
|
||||
91
backend/app/ml/spending_forecast.py
Normal file
91
backend/app/ml/spending_forecast.py
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from datetime import date
|
||||
from dateutil.relativedelta import relativedelta
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
MIN_POINTS = 3
|
||||
FORECAST_MONTHS = 3
|
||||
|
||||
|
||||
def _next_month_starts(from_date: date, n: int) -> list[str]:
|
||||
months = []
|
||||
d = (from_date.replace(day=1) + relativedelta(months=1))
|
||||
for _ in range(n):
|
||||
months.append(d.strftime("%Y-%m-%d"))
|
||||
d += relativedelta(months=1)
|
||||
return months
|
||||
|
||||
|
||||
def _fit_holt(values: list[float], n: int) -> tuple[list[float], list[float], list[float]]:
|
||||
from statsmodels.tsa.holtwinters import ExponentialSmoothing
|
||||
|
||||
try:
|
||||
if len(values) >= 12:
|
||||
model = ExponentialSmoothing(values, trend="add", seasonal="add", seasonal_periods=12)
|
||||
elif len(values) >= 4:
|
||||
model = ExponentialSmoothing(values, trend="add", seasonal=None)
|
||||
else:
|
||||
model = ExponentialSmoothing(values, trend=None, seasonal=None)
|
||||
|
||||
fit = model.fit(optimized=True, disp=False)
|
||||
forecast = fit.forecast(n)
|
||||
sigma = float(np.std(fit.resid)) if len(fit.resid) > 1 else float(np.mean(values) * 0.15)
|
||||
lower = np.maximum(0, forecast - 1.28 * sigma)
|
||||
upper = forecast + 1.28 * sigma
|
||||
return forecast.tolist(), lower.tolist(), upper.tolist()
|
||||
except Exception:
|
||||
avg = float(np.mean(values))
|
||||
sigma = float(np.std(values)) if len(values) > 1 else avg * 0.15
|
||||
return [avg] * n, [max(0, avg - 1.28 * sigma)] * n, [(avg + 1.28 * sigma)] * n
|
||||
|
||||
|
||||
def forecast_spending(df: pd.DataFrame) -> list[dict]:
|
||||
"""
|
||||
df columns: category_id, category_name, ds (monthly), y (amount)
|
||||
Returns list of category forecast dicts.
|
||||
"""
|
||||
if df.empty:
|
||||
return []
|
||||
|
||||
today = date.today()
|
||||
future_dates = _next_month_starts(today, FORECAST_MONTHS)
|
||||
results = []
|
||||
|
||||
for (cat_id, cat_name), group in df.groupby(["category_id", "category_name"]):
|
||||
group = group.sort_values("ds")
|
||||
values = group["y"].tolist()
|
||||
actuals = [
|
||||
{"date": row["ds"].strftime("%Y-%m-%d"), "amount": row["y"]}
|
||||
for _, row in group.iterrows()
|
||||
]
|
||||
|
||||
if len(values) < MIN_POINTS:
|
||||
avg = float(np.mean(values))
|
||||
forecast_pts = [
|
||||
{"date": d, "amount": round(avg, 2), "lower": round(avg * 0.7, 2), "upper": round(avg * 1.3, 2)}
|
||||
for d in future_dates
|
||||
]
|
||||
else:
|
||||
fcast, lower, upper = _fit_holt(values, FORECAST_MONTHS)
|
||||
forecast_pts = [
|
||||
{"date": d, "amount": round(max(0, f), 2), "lower": round(l, 2), "upper": round(u, 2)}
|
||||
for d, f, l, u in zip(future_dates, fcast, lower, upper)
|
||||
]
|
||||
|
||||
results.append({
|
||||
"category_id": cat_id,
|
||||
"category_name": cat_name,
|
||||
"monthly_avg": round(float(np.mean(values)), 2),
|
||||
"actuals": actuals[-6:], # last 6 months for display
|
||||
"forecast": forecast_pts,
|
||||
})
|
||||
|
||||
# Sort by monthly_avg descending (highest spend first)
|
||||
results.sort(key=lambda x: x["monthly_avg"], reverse=True)
|
||||
return results
|
||||
0
backend/app/schemas/__init__.py
Normal file
0
backend/app/schemas/__init__.py
Normal file
59
backend/app/schemas/account.py
Normal file
59
backend/app/schemas/account.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
AccountType = Literal[
|
||||
"checking", "savings", "cash_isa", "stocks_shares_isa",
|
||||
"credit_card", "investment", "cash", "crypto_wallet",
|
||||
"loan", "mortgage", "pension", "other"
|
||||
]
|
||||
|
||||
|
||||
class AccountCreate(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=100)
|
||||
institution: str | None = None
|
||||
type: AccountType
|
||||
currency: str = Field(default="GBP", min_length=3, max_length=10)
|
||||
credit_limit: Decimal | None = None
|
||||
interest_rate: Decimal | None = None
|
||||
include_in_net_worth: bool = True
|
||||
color: str = Field(default="#6366f1", pattern=r"^#[0-9a-fA-F]{6}$")
|
||||
icon: str | None = None
|
||||
notes: str | None = None
|
||||
opening_balance: Decimal = Field(default=Decimal("0"))
|
||||
|
||||
|
||||
class AccountUpdate(BaseModel):
|
||||
name: str | None = Field(default=None, min_length=1, max_length=100)
|
||||
institution: str | None = None
|
||||
opening_balance: Decimal | None = None
|
||||
credit_limit: Decimal | None = None
|
||||
interest_rate: Decimal | None = None
|
||||
include_in_net_worth: bool | None = None
|
||||
is_active: bool | None = None
|
||||
color: str | None = Field(default=None, pattern=r"^#[0-9a-fA-F]{6}$")
|
||||
icon: str | None = None
|
||||
notes: str | None = None
|
||||
|
||||
|
||||
class AccountResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
institution: str | None
|
||||
type: str
|
||||
currency: str
|
||||
current_balance: Decimal
|
||||
credit_limit: Decimal | None
|
||||
interest_rate: Decimal | None
|
||||
is_active: bool
|
||||
include_in_net_worth: bool
|
||||
color: str
|
||||
icon: str | None
|
||||
notes: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
64
backend/app/schemas/auth.py
Normal file
64
backend/app/schemas/auth.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, EmailStr, field_validator
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
display_name: str
|
||||
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
def password_strength(cls, v: str) -> str:
|
||||
if len(v) < 12:
|
||||
raise ValueError("Password must be at least 12 characters")
|
||||
if not any(c.isupper() for c in v):
|
||||
raise ValueError("Password must contain an uppercase letter")
|
||||
if not any(c.isdigit() for c in v):
|
||||
raise ValueError("Password must contain a digit")
|
||||
return v
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class TOTPChallengeResponse(BaseModel):
|
||||
totp_required: bool = True
|
||||
challenge_token: str
|
||||
|
||||
|
||||
class TOTPLoginRequest(BaseModel):
|
||||
challenge_token: str
|
||||
totp_code: str
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
access_token: str
|
||||
token_type: str = "bearer"
|
||||
expires_in: int # seconds
|
||||
|
||||
|
||||
class TOTPSetupResponse(BaseModel):
|
||||
secret: str
|
||||
qr_code_png_b64: str
|
||||
backup_codes: list[str]
|
||||
|
||||
|
||||
class TOTPVerifyRequest(BaseModel):
|
||||
code: str
|
||||
|
||||
|
||||
class SessionInfo(BaseModel):
|
||||
id: uuid.UUID
|
||||
ip_address: str | None
|
||||
user_agent: str | None
|
||||
last_active_at: datetime
|
||||
expires_at: datetime
|
||||
created_at: datetime
|
||||
is_current: bool = False
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
65
backend/app/schemas/budget.py
Normal file
65
backend/app/schemas/budget.py
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
import uuid
|
||||
from datetime import date as DateType, datetime
|
||||
from decimal import Decimal
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
BudgetPeriod = Literal["weekly", "monthly", "quarterly", "yearly"]
|
||||
|
||||
|
||||
class BudgetCreate(BaseModel):
|
||||
category_id: uuid.UUID
|
||||
name: str = Field(..., min_length=1, max_length=200)
|
||||
amount: Decimal = Field(..., gt=0)
|
||||
currency: str = Field(default="GBP", min_length=3, max_length=10)
|
||||
period: BudgetPeriod = "monthly"
|
||||
start_date: DateType
|
||||
end_date: DateType | None = None
|
||||
rollover: bool = False
|
||||
alert_threshold: Decimal = Field(default=Decimal("80"), ge=0, le=100)
|
||||
|
||||
|
||||
class BudgetUpdate(BaseModel):
|
||||
name: str | None = Field(default=None, min_length=1, max_length=200)
|
||||
amount: Decimal | None = Field(default=None, gt=0)
|
||||
period: BudgetPeriod | None = None
|
||||
end_date: DateType | None = None
|
||||
rollover: bool | None = None
|
||||
alert_threshold: Decimal | None = Field(default=None, ge=0, le=100)
|
||||
is_active: bool | None = None
|
||||
|
||||
|
||||
class BudgetResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
category_id: uuid.UUID
|
||||
name: str
|
||||
amount: Decimal
|
||||
currency: str
|
||||
period: str
|
||||
start_date: DateType
|
||||
end_date: DateType | None
|
||||
rollover: bool
|
||||
alert_threshold: Decimal
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class BudgetSummaryItem(BaseModel):
|
||||
budget_id: uuid.UUID
|
||||
budget_name: str
|
||||
category_id: uuid.UUID
|
||||
category_name: str
|
||||
period: str
|
||||
budget_amount: Decimal
|
||||
spent_amount: Decimal
|
||||
remaining_amount: Decimal
|
||||
percent_used: Decimal
|
||||
is_over_budget: bool
|
||||
alert_triggered: bool
|
||||
currency: str
|
||||
period_start: DateType
|
||||
period_end: DateType
|
||||
103
backend/app/schemas/investment.py
Normal file
103
backend/app/schemas/investment.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
import uuid
|
||||
from datetime import date as DateType, datetime
|
||||
from decimal import Decimal
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
InvestmentTxnType = Literal["buy", "sell", "dividend", "split", "fee", "transfer_in", "transfer_out"]
|
||||
|
||||
|
||||
class AssetSearch(BaseModel):
|
||||
id: uuid.UUID
|
||||
symbol: str
|
||||
name: str
|
||||
type: str
|
||||
currency: str
|
||||
exchange: str | None
|
||||
last_price: Decimal | None
|
||||
price_change_24h: Decimal | None
|
||||
data_source: str
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class AssetPricePoint(BaseModel):
|
||||
date: DateType
|
||||
open: Decimal | None
|
||||
high: Decimal | None
|
||||
low: Decimal | None
|
||||
close: Decimal
|
||||
volume: Decimal | None
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class HoldingCreate(BaseModel):
|
||||
account_id: uuid.UUID
|
||||
asset_id: uuid.UUID
|
||||
quantity: Decimal = Field(..., gt=0)
|
||||
avg_cost_basis: Decimal = Field(..., ge=0)
|
||||
currency: str = Field(default="GBP", min_length=3, max_length=10)
|
||||
|
||||
|
||||
class HoldingResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
account_id: uuid.UUID
|
||||
asset_id: uuid.UUID
|
||||
symbol: str
|
||||
asset_name: str
|
||||
asset_type: str
|
||||
quantity: Decimal
|
||||
avg_cost_basis: Decimal
|
||||
current_price: Decimal | None
|
||||
current_value: Decimal | None
|
||||
cost_basis_total: Decimal
|
||||
unrealised_gain: Decimal | None
|
||||
unrealised_gain_pct: Decimal | None
|
||||
currency: str
|
||||
price_change_24h: Decimal | None
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class InvestmentTxnCreate(BaseModel):
|
||||
holding_id: uuid.UUID
|
||||
type: InvestmentTxnType
|
||||
quantity: Decimal = Field(..., ge=0)
|
||||
price: Decimal = Field(..., ge=0)
|
||||
fees: Decimal = Field(default=Decimal("0"), ge=0)
|
||||
currency: str = Field(default="GBP", min_length=3, max_length=10)
|
||||
date: DateType
|
||||
notes: str | None = None
|
||||
|
||||
|
||||
class InvestmentTxnResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
holding_id: uuid.UUID
|
||||
type: str
|
||||
quantity: Decimal
|
||||
price: Decimal
|
||||
fees: Decimal
|
||||
total_amount: Decimal
|
||||
currency: str
|
||||
date: DateType
|
||||
created_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class PortfolioSummary(BaseModel):
|
||||
total_value: Decimal
|
||||
total_cost: Decimal
|
||||
total_gain: Decimal
|
||||
total_gain_pct: Decimal
|
||||
currency: str
|
||||
holdings: list[HoldingResponse]
|
||||
|
||||
|
||||
class PerformanceMetrics(BaseModel):
|
||||
twrr: Decimal | None
|
||||
total_return: Decimal
|
||||
total_return_pct: Decimal
|
||||
currency: str
|
||||
96
backend/app/schemas/report.py
Normal file
96
backend/app/schemas/report.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
from datetime import date as DateType
|
||||
from decimal import Decimal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class NetWorthPoint(BaseModel):
|
||||
date: DateType
|
||||
total_assets: Decimal
|
||||
total_liabilities: Decimal
|
||||
net_worth: Decimal
|
||||
base_currency: str
|
||||
|
||||
|
||||
class NetWorthReport(BaseModel):
|
||||
points: list[NetWorthPoint]
|
||||
current_net_worth: Decimal
|
||||
change_30d: Decimal
|
||||
change_30d_pct: Decimal
|
||||
base_currency: str
|
||||
|
||||
|
||||
class IncomeExpensePoint(BaseModel):
|
||||
month: str # "2024-01"
|
||||
income: Decimal
|
||||
expenses: Decimal
|
||||
net: Decimal
|
||||
|
||||
|
||||
class IncomeExpenseReport(BaseModel):
|
||||
points: list[IncomeExpensePoint]
|
||||
total_income: Decimal
|
||||
total_expenses: Decimal
|
||||
avg_monthly_income: Decimal
|
||||
avg_monthly_expenses: Decimal
|
||||
currency: str
|
||||
|
||||
|
||||
class CashFlowPoint(BaseModel):
|
||||
date: DateType
|
||||
inflow: Decimal
|
||||
outflow: Decimal
|
||||
net: Decimal
|
||||
running_balance: Decimal
|
||||
|
||||
|
||||
class CashFlowReport(BaseModel):
|
||||
points: list[CashFlowPoint]
|
||||
total_inflow: Decimal
|
||||
total_outflow: Decimal
|
||||
currency: str
|
||||
|
||||
|
||||
class CategoryBreakdownItem(BaseModel):
|
||||
category_id: str | None
|
||||
category_name: str
|
||||
amount: Decimal
|
||||
percent: Decimal
|
||||
transaction_count: int
|
||||
|
||||
|
||||
class CategoryBreakdownReport(BaseModel):
|
||||
items: list[CategoryBreakdownItem]
|
||||
total: Decimal
|
||||
currency: str
|
||||
date_from: DateType
|
||||
date_to: DateType
|
||||
|
||||
|
||||
class BudgetVsActualItem(BaseModel):
|
||||
budget_id: str
|
||||
budget_name: str
|
||||
category_name: str
|
||||
budgeted: Decimal
|
||||
actual: Decimal
|
||||
variance: Decimal
|
||||
percent_used: Decimal
|
||||
|
||||
|
||||
class BudgetVsActualReport(BaseModel):
|
||||
items: list[BudgetVsActualItem]
|
||||
total_budgeted: Decimal
|
||||
total_actual: Decimal
|
||||
currency: str
|
||||
|
||||
|
||||
class SpendingTrendPoint(BaseModel):
|
||||
month: str
|
||||
category_name: str
|
||||
amount: Decimal
|
||||
|
||||
|
||||
class SpendingTrendsReport(BaseModel):
|
||||
points: list[SpendingTrendPoint]
|
||||
categories: list[str]
|
||||
currency: str
|
||||
77
backend/app/schemas/transaction.py
Normal file
77
backend/app/schemas/transaction.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
import uuid
|
||||
from datetime import date as DateType, datetime
|
||||
from decimal import Decimal
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
TransactionType = Literal["income", "expense", "transfer", "investment"]
|
||||
TransactionStatus = Literal["pending", "cleared", "reconciled", "void"]
|
||||
|
||||
|
||||
class TransactionCreate(BaseModel):
|
||||
account_id: uuid.UUID
|
||||
transfer_account_id: uuid.UUID | None = None
|
||||
category_id: uuid.UUID | None = None
|
||||
type: TransactionType
|
||||
status: TransactionStatus = "cleared"
|
||||
amount: Decimal
|
||||
currency: str = Field(default="GBP", min_length=3, max_length=10)
|
||||
date: DateType
|
||||
description: str = Field(..., min_length=1, max_length=500)
|
||||
merchant: str | None = None
|
||||
notes: str | None = None
|
||||
tags: list[str] = []
|
||||
is_recurring: bool = False
|
||||
recurring_rule: dict | None = None
|
||||
|
||||
|
||||
class TransactionUpdate(BaseModel):
|
||||
category_id: uuid.UUID | None = None
|
||||
status: TransactionStatus | None = None
|
||||
amount: Decimal | None = None
|
||||
date: DateType | None = None
|
||||
description: str | None = Field(default=None, min_length=1, max_length=500)
|
||||
merchant: str | None = None
|
||||
notes: str | None = None
|
||||
tags: list[str] | None = None
|
||||
|
||||
|
||||
class TransactionFilter(BaseModel):
|
||||
account_id: uuid.UUID | None = None
|
||||
category_id: uuid.UUID | None = None
|
||||
type: TransactionType | None = None
|
||||
status: TransactionStatus | None = None
|
||||
date_from: DateType | None = None
|
||||
date_to: DateType | None = None
|
||||
min_amount: Decimal | None = None
|
||||
max_amount: Decimal | None = None
|
||||
search: str | None = None
|
||||
tags: list[str] = []
|
||||
page: int = Field(default=1, ge=1)
|
||||
page_size: int = Field(default=50, ge=1, le=200)
|
||||
|
||||
|
||||
class TransactionResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
account_id: uuid.UUID
|
||||
transfer_account_id: uuid.UUID | None
|
||||
category_id: uuid.UUID | None
|
||||
type: str
|
||||
status: str
|
||||
amount: Decimal
|
||||
amount_base: Decimal | None
|
||||
currency: str
|
||||
base_currency: str
|
||||
exchange_rate: Decimal | None
|
||||
date: DateType
|
||||
description: str
|
||||
merchant: str | None
|
||||
notes: str | None
|
||||
tags: list[str]
|
||||
is_recurring: bool
|
||||
attachment_refs: list[dict] = []
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
0
backend/app/services/__init__.py
Normal file
0
backend/app/services/__init__.py
Normal file
195
backend/app/services/account_service.py
Normal file
195
backend/app/services/account_service.py
Normal file
|
|
@ -0,0 +1,195 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.security import encrypt_field, decrypt_field
|
||||
from app.db.models.account import Account
|
||||
from app.db.models.transaction import Transaction
|
||||
from app.schemas.account import AccountCreate, AccountUpdate
|
||||
|
||||
# Account types that are liabilities (balance is negative contribution to net worth)
|
||||
LIABILITY_TYPES = {"credit_card", "loan", "mortgage"}
|
||||
|
||||
|
||||
class AccountError(Exception):
|
||||
def __init__(self, detail: str, status_code: int = 400):
|
||||
self.detail = detail
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
def _encrypt(value: str | None) -> bytes | None:
|
||||
if value is None:
|
||||
return None
|
||||
return encrypt_field(value)
|
||||
|
||||
|
||||
def _decrypt(data: bytes | None) -> str | None:
|
||||
if not data:
|
||||
return None
|
||||
return decrypt_field(data)
|
||||
|
||||
|
||||
def _to_response(account: Account) -> dict:
|
||||
return {
|
||||
"id": account.id,
|
||||
"name": _decrypt(account.name_enc) or "",
|
||||
"institution": _decrypt(account.institution_enc),
|
||||
"type": account.type,
|
||||
"currency": account.currency,
|
||||
"current_balance": account.current_balance,
|
||||
"credit_limit": account.credit_limit,
|
||||
"interest_rate": account.interest_rate,
|
||||
"is_active": account.is_active,
|
||||
"include_in_net_worth": account.include_in_net_worth,
|
||||
"color": account.color,
|
||||
"icon": account.icon,
|
||||
"notes": _decrypt(account.notes_enc),
|
||||
"created_at": account.created_at,
|
||||
"updated_at": account.updated_at,
|
||||
}
|
||||
|
||||
|
||||
async def create_account(
|
||||
db: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
data: AccountCreate,
|
||||
) -> dict:
|
||||
now = datetime.now(timezone.utc)
|
||||
account = Account(
|
||||
user_id=user_id,
|
||||
name_enc=encrypt_field(data.name),
|
||||
institution_enc=_encrypt(data.institution),
|
||||
type=data.type,
|
||||
currency=data.currency,
|
||||
current_balance=data.opening_balance,
|
||||
credit_limit=data.credit_limit,
|
||||
interest_rate=data.interest_rate,
|
||||
include_in_net_worth=data.include_in_net_worth,
|
||||
color=data.color,
|
||||
icon=data.icon,
|
||||
notes_enc=_encrypt(data.notes),
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
db.add(account)
|
||||
await db.flush()
|
||||
return _to_response(account)
|
||||
|
||||
|
||||
async def list_accounts(db: AsyncSession, user_id: uuid.UUID) -> list[dict]:
|
||||
result = await db.execute(
|
||||
select(Account).where(
|
||||
Account.user_id == user_id,
|
||||
Account.deleted_at.is_(None),
|
||||
).order_by(Account.created_at)
|
||||
)
|
||||
return [_to_response(a) for a in result.scalars()]
|
||||
|
||||
|
||||
async def get_account(db: AsyncSession, account_id: uuid.UUID, user_id: uuid.UUID) -> Account:
|
||||
result = await db.execute(
|
||||
select(Account).where(
|
||||
Account.id == account_id,
|
||||
Account.user_id == user_id,
|
||||
Account.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
account = result.scalar_one_or_none()
|
||||
if not account:
|
||||
raise AccountError("Account not found", status_code=404)
|
||||
return account
|
||||
|
||||
|
||||
async def update_account(
|
||||
db: AsyncSession,
|
||||
account_id: uuid.UUID,
|
||||
user_id: uuid.UUID,
|
||||
data: AccountUpdate,
|
||||
) -> dict:
|
||||
account = await get_account(db, account_id, user_id)
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
if data.name is not None:
|
||||
account.name_enc = encrypt_field(data.name)
|
||||
if data.institution is not None:
|
||||
account.institution_enc = _encrypt(data.institution)
|
||||
if data.opening_balance is not None:
|
||||
account.current_balance = data.opening_balance
|
||||
if data.credit_limit is not None:
|
||||
account.credit_limit = data.credit_limit
|
||||
if data.interest_rate is not None:
|
||||
account.interest_rate = data.interest_rate
|
||||
if data.include_in_net_worth is not None:
|
||||
account.include_in_net_worth = data.include_in_net_worth
|
||||
if data.is_active is not None:
|
||||
account.is_active = data.is_active
|
||||
if data.color is not None:
|
||||
account.color = data.color
|
||||
if data.icon is not None:
|
||||
account.icon = data.icon
|
||||
if data.notes is not None:
|
||||
account.notes_enc = _encrypt(data.notes)
|
||||
|
||||
account.updated_at = now
|
||||
await db.flush()
|
||||
return _to_response(account)
|
||||
|
||||
|
||||
async def delete_account(
|
||||
db: AsyncSession,
|
||||
account_id: uuid.UUID,
|
||||
user_id: uuid.UUID,
|
||||
) -> None:
|
||||
account = await get_account(db, account_id, user_id)
|
||||
account.deleted_at = datetime.now(timezone.utc)
|
||||
account.updated_at = datetime.now(timezone.utc)
|
||||
await db.flush()
|
||||
|
||||
|
||||
async def recalculate_balance(db: AsyncSession, account_id: uuid.UUID) -> None:
|
||||
"""Recompute current_balance from all non-deleted transactions."""
|
||||
result = await db.execute(
|
||||
select(func.sum(Transaction.amount)).where(
|
||||
Transaction.account_id == account_id,
|
||||
Transaction.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
total = result.scalar_one_or_none() or Decimal("0")
|
||||
|
||||
account = await db.get(Account, account_id)
|
||||
if account:
|
||||
account.current_balance = total
|
||||
account.updated_at = datetime.now(timezone.utc)
|
||||
await db.flush()
|
||||
|
||||
|
||||
async def get_net_worth(db: AsyncSession, user_id: uuid.UUID, base_currency: str) -> dict:
|
||||
accounts = await db.execute(
|
||||
select(Account).where(
|
||||
Account.user_id == user_id,
|
||||
Account.include_in_net_worth == True,
|
||||
Account.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
total_assets = Decimal("0")
|
||||
total_liabilities = Decimal("0")
|
||||
|
||||
for account in accounts.scalars():
|
||||
# TODO Phase 3: convert to base_currency via FX rates
|
||||
bal = account.current_balance
|
||||
if account.type in LIABILITY_TYPES:
|
||||
total_liabilities += abs(bal)
|
||||
else:
|
||||
total_assets += bal
|
||||
|
||||
return {
|
||||
"total_assets": total_assets,
|
||||
"total_liabilities": total_liabilities,
|
||||
"net_worth": total_assets - total_liabilities,
|
||||
"base_currency": base_currency,
|
||||
}
|
||||
258
backend/app/services/auth_service.py
Normal file
258
backend/app/services/auth_service.py
Normal file
|
|
@ -0,0 +1,258 @@
|
|||
"""
|
||||
Authentication service: register, login, TOTP, sessions, brute-force protection.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from jose import JWTError
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import get_settings
|
||||
from app.core.security import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decrypt_field,
|
||||
decode_token,
|
||||
encrypt_field,
|
||||
generate_backup_codes,
|
||||
generate_csrf_token,
|
||||
generate_totp_qr_png,
|
||||
generate_totp_secret,
|
||||
hash_password,
|
||||
hash_token,
|
||||
verify_password,
|
||||
verify_totp,
|
||||
)
|
||||
from app.db.models.session import Session
|
||||
from app.db.models.user import User
|
||||
|
||||
|
||||
class AuthError(Exception):
|
||||
def __init__(self, detail: str, status_code: int = 401):
|
||||
self.detail = detail
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
async def _lockout_key(email: str) -> str:
|
||||
return f"lockout:{email}"
|
||||
|
||||
|
||||
async def _check_and_record_failure(redis: Redis, email: str, settings) -> None:
|
||||
key = await _lockout_key(email)
|
||||
attempts = await redis.incr(key)
|
||||
if attempts == 1:
|
||||
await redis.expire(key, settings.lockout_base_seconds)
|
||||
if attempts >= settings.max_login_attempts:
|
||||
lockout_seconds = settings.lockout_base_seconds * (2 ** (attempts - settings.max_login_attempts))
|
||||
await redis.expire(key, min(lockout_seconds, 86400)) # cap at 24h
|
||||
|
||||
|
||||
async def _is_locked_out(redis: Redis, email: str) -> bool:
|
||||
key = await _lockout_key(email)
|
||||
val = await redis.get(key)
|
||||
if val is None:
|
||||
return False
|
||||
settings = get_settings()
|
||||
return int(val) >= settings.max_login_attempts
|
||||
|
||||
|
||||
async def register_user(db: AsyncSession, email: str, password: str, display_name: str) -> User:
|
||||
settings = get_settings()
|
||||
|
||||
# Single-user: block registration if user already exists
|
||||
if not settings.allow_registration:
|
||||
count = await db.scalar(select(func.count()).select_from(User).where(User.deleted_at.is_(None)))
|
||||
if count and count > 0:
|
||||
raise AuthError("Registration is disabled", status_code=403)
|
||||
|
||||
existing = await db.scalar(select(User).where(User.email == email))
|
||||
if existing:
|
||||
raise AuthError("Email already registered", status_code=409)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
user = User(
|
||||
email=email,
|
||||
password_hash=hash_password(password),
|
||||
display_name=display_name,
|
||||
base_currency=settings.base_currency,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush()
|
||||
return user
|
||||
|
||||
|
||||
async def authenticate_user(
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
email: str,
|
||||
password: str,
|
||||
ip: str | None,
|
||||
user_agent: str | None,
|
||||
) -> tuple[User, str, str] | tuple[User, None, None]:
|
||||
"""
|
||||
Returns (user, access_token, refresh_token) if no TOTP required,
|
||||
or (user, None, None) if TOTP challenge needed.
|
||||
Raises AuthError on failure.
|
||||
"""
|
||||
settings = get_settings()
|
||||
|
||||
if await _is_locked_out(redis, email):
|
||||
raise AuthError("Account temporarily locked due to too many failed attempts", status_code=429)
|
||||
|
||||
user = await db.scalar(
|
||||
select(User).where(User.email == email, User.deleted_at.is_(None))
|
||||
)
|
||||
if not user or not verify_password(password, user.password_hash):
|
||||
await _check_and_record_failure(redis, email, settings)
|
||||
raise AuthError("Invalid email or password")
|
||||
|
||||
# Clear lockout on success
|
||||
await redis.delete(await _lockout_key(email))
|
||||
|
||||
if user.totp_enabled:
|
||||
return user, None, None # Caller creates challenge token
|
||||
|
||||
tokens = await _create_session(db, user, ip, user_agent)
|
||||
return user, *tokens
|
||||
|
||||
|
||||
async def _create_session(
|
||||
db: AsyncSession,
|
||||
user: User,
|
||||
ip: str | None,
|
||||
user_agent: str | None,
|
||||
) -> tuple[str, str]:
|
||||
settings = get_settings()
|
||||
access_token = create_access_token(str(user.id))
|
||||
refresh_token = create_refresh_token(str(user.id))
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
session = Session(
|
||||
user_id=user.id,
|
||||
token_hash=hash_token(access_token),
|
||||
ip_address=ip,
|
||||
user_agent=user_agent,
|
||||
last_active_at=now,
|
||||
expires_at=now + timedelta(days=settings.refresh_token_expire_days),
|
||||
created_at=now,
|
||||
)
|
||||
db.add(session)
|
||||
await db.flush()
|
||||
|
||||
# Update user login info
|
||||
user.last_login_at = now
|
||||
user.last_login_ip = ip
|
||||
user.updated_at = now
|
||||
|
||||
return access_token, refresh_token
|
||||
|
||||
|
||||
async def complete_totp_login(
|
||||
db: AsyncSession,
|
||||
challenge_token: str,
|
||||
totp_code: str,
|
||||
ip: str | None,
|
||||
user_agent: str | None,
|
||||
) -> tuple[str, str]:
|
||||
try:
|
||||
payload = decode_token(challenge_token, token_type="totp_challenge")
|
||||
user_id = uuid.UUID(payload["sub"])
|
||||
except (JWTError, ValueError, KeyError):
|
||||
raise AuthError("Invalid or expired challenge token")
|
||||
|
||||
user = await db.get(User, user_id)
|
||||
if not user or not user.totp_enabled or not user.totp_secret_enc:
|
||||
raise AuthError("Invalid challenge")
|
||||
|
||||
secret = decrypt_field(bytes.fromhex(user.totp_secret_enc) if isinstance(user.totp_secret_enc, str) else user.totp_secret_enc)
|
||||
if not verify_totp(secret, totp_code):
|
||||
raise AuthError("Invalid TOTP code")
|
||||
|
||||
return await _create_session(db, user, ip, user_agent)
|
||||
|
||||
|
||||
def create_totp_challenge_token(user_id: uuid.UUID) -> str:
|
||||
from app.core.security import create_access_token
|
||||
from datetime import timedelta
|
||||
from datetime import datetime, timezone
|
||||
from app.config import get_settings
|
||||
from jose import jwt
|
||||
from pathlib import Path
|
||||
|
||||
settings = get_settings()
|
||||
now = datetime.now(timezone.utc)
|
||||
payload = {
|
||||
"sub": str(user_id),
|
||||
"iat": now,
|
||||
"exp": now + timedelta(minutes=5),
|
||||
"type": "totp_challenge",
|
||||
}
|
||||
private_key = Path(settings.jwt_private_key_file).read_text()
|
||||
return jwt.encode(payload, private_key, algorithm=settings.jwt_algorithm)
|
||||
|
||||
|
||||
async def setup_totp(user: User, db: AsyncSession) -> tuple[str, str, list[str]]:
|
||||
"""Generate TOTP secret, QR code, and backup codes. Does not enable TOTP yet."""
|
||||
secret = generate_totp_secret()
|
||||
qr_png = generate_totp_qr_png(secret, user.email)
|
||||
backup_codes = generate_backup_codes(8)
|
||||
return secret, base64.b64encode(qr_png).decode(), backup_codes
|
||||
|
||||
|
||||
async def enable_totp(user: User, db: AsyncSession, secret: str, code: str) -> None:
|
||||
if not verify_totp(secret, code):
|
||||
raise AuthError("Invalid TOTP code — setup failed", status_code=400)
|
||||
|
||||
encrypted = encrypt_field(secret)
|
||||
user.totp_secret_enc = encrypted.hex()
|
||||
user.totp_enabled = True
|
||||
user.updated_at = datetime.now(timezone.utc)
|
||||
await db.flush()
|
||||
|
||||
|
||||
async def disable_totp(user: User, db: AsyncSession, password: str) -> None:
|
||||
if not verify_password(password, user.password_hash):
|
||||
raise AuthError("Incorrect password", status_code=400)
|
||||
user.totp_secret_enc = None
|
||||
user.totp_enabled = False
|
||||
user.totp_backup_codes_enc = None
|
||||
user.updated_at = datetime.now(timezone.utc)
|
||||
await db.flush()
|
||||
|
||||
|
||||
async def revoke_session(db: AsyncSession, session_id: uuid.UUID, user_id: uuid.UUID) -> None:
|
||||
session = await db.get(Session, session_id)
|
||||
if not session or session.user_id != user_id:
|
||||
raise AuthError("Session not found", status_code=404)
|
||||
session.revoked_at = datetime.now(timezone.utc)
|
||||
await db.flush()
|
||||
|
||||
|
||||
async def revoke_all_sessions(db: AsyncSession, user_id: uuid.UUID, except_token_hash: str | None = None) -> None:
|
||||
from sqlalchemy import update
|
||||
stmt = (
|
||||
update(Session)
|
||||
.where(Session.user_id == user_id, Session.revoked_at.is_(None))
|
||||
)
|
||||
if except_token_hash:
|
||||
stmt = stmt.where(Session.token_hash != except_token_hash)
|
||||
stmt = stmt.values(revoked_at=datetime.now(timezone.utc))
|
||||
await db.execute(stmt)
|
||||
|
||||
|
||||
async def get_sessions(db: AsyncSession, user_id: uuid.UUID) -> list[Session]:
|
||||
result = await db.execute(
|
||||
select(Session).where(
|
||||
Session.user_id == user_id,
|
||||
Session.revoked_at.is_(None),
|
||||
Session.expires_at > datetime.now(timezone.utc),
|
||||
).order_by(Session.created_at.desc())
|
||||
)
|
||||
return list(result.scalars())
|
||||
137
backend/app/services/budget_service.py
Normal file
137
backend/app/services/budget_service.py
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
import uuid
|
||||
from datetime import date, datetime, timezone
|
||||
from decimal import Decimal
|
||||
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from sqlalchemy import and_, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models.budget import Budget
|
||||
from app.db.models.category import Category
|
||||
from app.db.models.transaction import Transaction
|
||||
from app.schemas.budget import BudgetCreate, BudgetSummaryItem, BudgetUpdate
|
||||
|
||||
|
||||
def _period_bounds(period: str, ref: date) -> tuple[date, date]:
|
||||
if period == "weekly":
|
||||
start = ref - relativedelta(days=ref.weekday())
|
||||
end = start + relativedelta(days=6)
|
||||
elif period == "monthly":
|
||||
start = ref.replace(day=1)
|
||||
end = (start + relativedelta(months=1)) - relativedelta(days=1)
|
||||
elif period == "quarterly":
|
||||
q = (ref.month - 1) // 3
|
||||
start = date(ref.year, q * 3 + 1, 1)
|
||||
end = (start + relativedelta(months=3)) - relativedelta(days=1)
|
||||
else: # yearly
|
||||
start = date(ref.year, 1, 1)
|
||||
end = date(ref.year, 12, 31)
|
||||
return start, end
|
||||
|
||||
|
||||
async def create_budget(db: AsyncSession, user_id: uuid.UUID, data: BudgetCreate) -> Budget:
|
||||
now = datetime.now(timezone.utc)
|
||||
budget = Budget(
|
||||
id=uuid.uuid4(),
|
||||
user_id=user_id,
|
||||
category_id=data.category_id,
|
||||
name=data.name,
|
||||
amount=data.amount,
|
||||
currency=data.currency,
|
||||
period=data.period,
|
||||
start_date=data.start_date,
|
||||
end_date=data.end_date,
|
||||
rollover=data.rollover,
|
||||
alert_threshold=data.alert_threshold,
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
db.add(budget)
|
||||
await db.flush()
|
||||
await db.refresh(budget)
|
||||
return budget
|
||||
|
||||
|
||||
async def list_budgets(db: AsyncSession, user_id: uuid.UUID, active_only: bool = True) -> list[Budget]:
|
||||
q = select(Budget).where(Budget.user_id == user_id)
|
||||
if active_only:
|
||||
q = q.where(Budget.is_active == True) # noqa: E712
|
||||
q = q.order_by(Budget.name)
|
||||
result = await db.execute(q)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
async def get_budget(db: AsyncSession, user_id: uuid.UUID, budget_id: uuid.UUID) -> Budget | None:
|
||||
result = await db.execute(
|
||||
select(Budget).where(Budget.id == budget_id, Budget.user_id == user_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def update_budget(db: AsyncSession, budget: Budget, data: BudgetUpdate) -> Budget:
|
||||
for field, value in data.model_dump(exclude_unset=True).items():
|
||||
setattr(budget, field, value)
|
||||
budget.updated_at = datetime.now(timezone.utc)
|
||||
await db.flush()
|
||||
await db.refresh(budget)
|
||||
return budget
|
||||
|
||||
|
||||
async def delete_budget(db: AsyncSession, budget: Budget) -> None:
|
||||
await db.delete(budget)
|
||||
await db.flush()
|
||||
|
||||
|
||||
async def get_budget_summary(db: AsyncSession, user_id: uuid.UUID) -> list[BudgetSummaryItem]:
|
||||
budgets = await list_budgets(db, user_id, active_only=True)
|
||||
today = date.today()
|
||||
items: list[BudgetSummaryItem] = []
|
||||
|
||||
for budget in budgets:
|
||||
period_start, period_end = _period_bounds(budget.period, today)
|
||||
|
||||
# Fetch category name
|
||||
cat_result = await db.execute(select(Category).where(Category.id == budget.category_id))
|
||||
category = cat_result.scalar_one_or_none()
|
||||
cat_name = category.name if category else "Unknown"
|
||||
|
||||
# Sum actual spending in this period
|
||||
spent_result = await db.execute(
|
||||
select(func.coalesce(func.sum(func.abs(Transaction.amount)), Decimal("0")))
|
||||
.where(
|
||||
and_(
|
||||
Transaction.user_id == user_id,
|
||||
Transaction.category_id == budget.category_id,
|
||||
Transaction.type == "expense",
|
||||
Transaction.status != "void",
|
||||
Transaction.date >= period_start,
|
||||
Transaction.date <= period_end,
|
||||
Transaction.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
)
|
||||
spent = Decimal(str(spent_result.scalar() or 0))
|
||||
remaining = budget.amount - spent
|
||||
pct = (spent / budget.amount * 100) if budget.amount > 0 else Decimal("0")
|
||||
|
||||
items.append(
|
||||
BudgetSummaryItem(
|
||||
budget_id=budget.id,
|
||||
budget_name=budget.name,
|
||||
category_id=budget.category_id,
|
||||
category_name=cat_name,
|
||||
period=budget.period,
|
||||
budget_amount=budget.amount,
|
||||
spent_amount=spent,
|
||||
remaining_amount=remaining,
|
||||
percent_used=pct.quantize(Decimal("0.01")),
|
||||
is_over_budget=spent > budget.amount,
|
||||
alert_triggered=pct >= budget.alert_threshold,
|
||||
currency=budget.currency,
|
||||
period_start=period_start,
|
||||
period_end=period_end,
|
||||
)
|
||||
)
|
||||
|
||||
return items
|
||||
135
backend/app/services/category_service.py
Normal file
135
backend/app/services/category_service.py
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models.category import Category
|
||||
|
||||
SYSTEM_CATEGORIES = [
|
||||
# Income
|
||||
{"name": "Salary", "type": "income", "icon": "briefcase", "color": "#22c55e"},
|
||||
{"name": "Freelance", "type": "income", "icon": "laptop", "color": "#22c55e"},
|
||||
{"name": "Investment Income", "type": "income", "icon": "trending-up", "color": "#22c55e"},
|
||||
{"name": "Rental Income", "type": "income", "icon": "home", "color": "#22c55e"},
|
||||
{"name": "Benefits", "type": "income", "icon": "shield", "color": "#22c55e"},
|
||||
{"name": "Other Income", "type": "income", "icon": "plus-circle", "color": "#22c55e"},
|
||||
# Expenses — Housing
|
||||
{"name": "Rent / Mortgage", "type": "expense", "icon": "home", "color": "#6366f1"},
|
||||
{"name": "Council Tax", "type": "expense", "icon": "landmark", "color": "#6366f1"},
|
||||
{"name": "Home Insurance", "type": "expense", "icon": "shield", "color": "#6366f1"},
|
||||
{"name": "Home Maintenance", "type": "expense", "icon": "wrench", "color": "#6366f1"},
|
||||
# Utilities
|
||||
{"name": "Electricity", "type": "expense", "icon": "zap", "color": "#f59e0b"},
|
||||
{"name": "Gas", "type": "expense", "icon": "flame", "color": "#f59e0b"},
|
||||
{"name": "Water", "type": "expense", "icon": "droplets", "color": "#f59e0b"},
|
||||
{"name": "Internet", "type": "expense", "icon": "wifi", "color": "#f59e0b"},
|
||||
{"name": "Phone", "type": "expense", "icon": "smartphone", "color": "#f59e0b"},
|
||||
# Food
|
||||
{"name": "Groceries", "type": "expense", "icon": "shopping-cart", "color": "#ec4899"},
|
||||
{"name": "Eating Out", "type": "expense", "icon": "utensils", "color": "#ec4899"},
|
||||
{"name": "Coffee", "type": "expense", "icon": "coffee", "color": "#ec4899"},
|
||||
{"name": "Takeaway", "type": "expense", "icon": "package", "color": "#ec4899"},
|
||||
# Transport
|
||||
{"name": "Fuel", "type": "expense", "icon": "fuel", "color": "#0ea5e9"},
|
||||
{"name": "Public Transport", "type": "expense", "icon": "bus", "color": "#0ea5e9"},
|
||||
{"name": "Car Insurance", "type": "expense", "icon": "car", "color": "#0ea5e9"},
|
||||
{"name": "Car Maintenance", "type": "expense", "icon": "wrench", "color": "#0ea5e9"},
|
||||
{"name": "Parking", "type": "expense", "icon": "parking-circle", "color": "#0ea5e9"},
|
||||
{"name": "Taxi / Ride share", "type": "expense", "icon": "map-pin", "color": "#0ea5e9"},
|
||||
# Health
|
||||
{"name": "Healthcare", "type": "expense", "icon": "heart-pulse", "color": "#ef4444"},
|
||||
{"name": "Pharmacy", "type": "expense", "icon": "pill", "color": "#ef4444"},
|
||||
{"name": "Gym", "type": "expense", "icon": "dumbbell", "color": "#ef4444"},
|
||||
# Personal
|
||||
{"name": "Clothing", "type": "expense", "icon": "shirt", "color": "#a855f7"},
|
||||
{"name": "Personal Care", "type": "expense", "icon": "sparkles", "color": "#a855f7"},
|
||||
{"name": "Subscriptions", "type": "expense", "icon": "repeat", "color": "#a855f7"},
|
||||
{"name": "Entertainment", "type": "expense", "icon": "tv", "color": "#a855f7"},
|
||||
{"name": "Holidays", "type": "expense", "icon": "plane", "color": "#a855f7"},
|
||||
# Finance
|
||||
{"name": "Loan Repayment", "type": "expense", "icon": "credit-card", "color": "#64748b"},
|
||||
{"name": "Mortgage Payment", "type": "expense", "icon": "building", "color": "#64748b"},
|
||||
{"name": "Bank Charges", "type": "expense", "icon": "landmark", "color": "#64748b"},
|
||||
{"name": "Interest Paid", "type": "expense", "icon": "percent", "color": "#64748b"},
|
||||
# Savings
|
||||
{"name": "Savings", "type": "expense", "icon": "piggy-bank", "color": "#10b981"},
|
||||
{"name": "Investments", "type": "expense", "icon": "trending-up", "color": "#10b981"},
|
||||
# Other
|
||||
{"name": "Gifts", "type": "expense", "icon": "gift", "color": "#f97316"},
|
||||
{"name": "Education", "type": "expense", "icon": "graduation-cap", "color": "#f97316"},
|
||||
{"name": "Other Expense", "type": "expense", "icon": "more-horizontal", "color": "#64748b"},
|
||||
# Transfers
|
||||
{"name": "Transfer", "type": "transfer", "icon": "arrow-left-right", "color": "#94a3b8"},
|
||||
]
|
||||
|
||||
|
||||
async def seed_system_categories(db: AsyncSession) -> None:
|
||||
existing = await db.scalar(
|
||||
select(Category).where(Category.is_system == True).limit(1)
|
||||
)
|
||||
if existing:
|
||||
return
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
for i, cat in enumerate(SYSTEM_CATEGORIES):
|
||||
db.add(Category(
|
||||
user_id=None,
|
||||
name=cat["name"],
|
||||
type=cat["type"],
|
||||
icon=cat.get("icon"),
|
||||
color=cat.get("color"),
|
||||
is_system=True,
|
||||
sort_order=i,
|
||||
created_at=now,
|
||||
))
|
||||
await db.flush()
|
||||
|
||||
|
||||
async def list_categories(db: AsyncSession, user_id: uuid.UUID) -> list[dict]:
|
||||
result = await db.execute(
|
||||
select(Category).where(
|
||||
(Category.user_id == user_id) | (Category.user_id.is_(None))
|
||||
).order_by(Category.type, Category.sort_order, Category.name)
|
||||
)
|
||||
cats = result.scalars().all()
|
||||
return [
|
||||
{
|
||||
"id": str(c.id),
|
||||
"name": c.name,
|
||||
"type": c.type,
|
||||
"icon": c.icon,
|
||||
"color": c.color,
|
||||
"is_system": c.is_system,
|
||||
"parent_id": str(c.parent_id) if c.parent_id else None,
|
||||
"sort_order": c.sort_order,
|
||||
}
|
||||
for c in cats
|
||||
]
|
||||
|
||||
|
||||
async def create_category(
|
||||
db: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
name: str,
|
||||
type_: str,
|
||||
icon: str | None = None,
|
||||
color: str | None = None,
|
||||
parent_id: uuid.UUID | None = None,
|
||||
) -> dict:
|
||||
now = datetime.now(timezone.utc)
|
||||
cat = Category(
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
type=type_,
|
||||
icon=icon,
|
||||
color=color,
|
||||
parent_id=parent_id,
|
||||
is_system=False,
|
||||
created_at=now,
|
||||
)
|
||||
db.add(cat)
|
||||
await db.flush()
|
||||
return {"id": str(cat.id), "name": cat.name, "type": cat.type, "icon": cat.icon, "color": cat.color, "is_system": False}
|
||||
237
backend/app/services/csv_detector.py
Normal file
237
backend/app/services/csv_detector.py
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
"""
|
||||
Auto-detect CSV bank export formats and produce a column mapping.
|
||||
Supports: Monzo, Starling, Revolut, Barclays, Lloyds, NatWest/RBS, HSBC, Santander.
|
||||
Falls back to a generic best-effort mapping for unknown formats.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import io
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal
|
||||
|
||||
|
||||
@dataclass
|
||||
class CsvMapping:
|
||||
date: str
|
||||
description: str
|
||||
amount: str | None = None # single signed amount column
|
||||
debit: str | None = None # separate debit column (positive value = money out)
|
||||
credit: str | None = None # separate credit column (positive value = money in)
|
||||
balance: str | None = None
|
||||
reference: str | None = None
|
||||
detected_format: str | None = None
|
||||
|
||||
def is_split(self) -> bool:
|
||||
return self.debit is not None and self.credit is not None
|
||||
|
||||
|
||||
KNOWN_FORMATS: list[dict] = [
|
||||
{
|
||||
"name": "Monzo",
|
||||
"detect": lambda h: {"transaction id", "emoji"}.issubset(h),
|
||||
"date": "Date",
|
||||
"description": "Name",
|
||||
"amount": "Amount",
|
||||
"balance": None,
|
||||
"reference": "Notes and #tags",
|
||||
},
|
||||
{
|
||||
"name": "Starling",
|
||||
"detect": lambda h: {"counter party", "spending category"}.issubset(h),
|
||||
"date": "Date",
|
||||
"description": "Counter Party",
|
||||
"amount": "Amount (GBP)",
|
||||
"balance": "Balance (GBP)",
|
||||
"reference": "Reference",
|
||||
},
|
||||
{
|
||||
"name": "Revolut",
|
||||
"detect": lambda h: {"product", "started date", "completed date"}.issubset(h),
|
||||
"date": "Started Date",
|
||||
"description": "Description",
|
||||
"amount": "Amount",
|
||||
"balance": "Balance",
|
||||
"reference": None,
|
||||
},
|
||||
{
|
||||
"name": "Barclays",
|
||||
"detect": lambda h: {"subcategory", "memo", "number"}.issubset(h),
|
||||
"date": "Date",
|
||||
"description": "Memo",
|
||||
"amount": "Amount",
|
||||
"balance": None,
|
||||
"reference": "Subcategory",
|
||||
},
|
||||
{
|
||||
"name": "Lloyds Bank",
|
||||
"detect": lambda h: {"transaction date", "debit amount", "credit amount", "transaction description"}.issubset(h),
|
||||
"date": "Transaction Date",
|
||||
"description": "Transaction Description",
|
||||
"debit": "Debit Amount",
|
||||
"credit": "Credit Amount",
|
||||
"balance": "Balance",
|
||||
"reference": None,
|
||||
},
|
||||
{
|
||||
"name": "Halifax",
|
||||
"detect": lambda h: {"transaction date", "debit amount", "credit amount", "transaction description"}.issubset(h),
|
||||
"date": "Transaction Date",
|
||||
"description": "Transaction Description",
|
||||
"debit": "Debit Amount",
|
||||
"credit": "Credit Amount",
|
||||
"balance": "Balance",
|
||||
"reference": None,
|
||||
},
|
||||
{
|
||||
"name": "NatWest / RBS",
|
||||
"detect": lambda h: {"date", "type", "description", "value", "balance"}.issubset(h) and "value" in h,
|
||||
"date": "Date",
|
||||
"description": "Description",
|
||||
"amount": "Value",
|
||||
"balance": "Balance",
|
||||
"reference": None,
|
||||
},
|
||||
{
|
||||
"name": "HSBC",
|
||||
"detect": lambda h: h == {"date", "description", "amount"} or h == {"date", "description", "debit", "credit", "balance"},
|
||||
"date": "Date",
|
||||
"description": "Description",
|
||||
"amount": "Amount",
|
||||
"balance": None,
|
||||
"reference": None,
|
||||
},
|
||||
{
|
||||
"name": "Santander",
|
||||
"detect": lambda h: {"date", "description", "debit", "credit", "balance"}.issubset(h),
|
||||
"date": "Date",
|
||||
"description": "Description",
|
||||
"debit": "Debit",
|
||||
"credit": "Credit",
|
||||
"balance": "Balance",
|
||||
"reference": None,
|
||||
},
|
||||
{
|
||||
"name": "Nationwide",
|
||||
"detect": lambda h: {"date", "transaction", "payments out", "payments in", "balance"}.issubset(h),
|
||||
"date": "Date",
|
||||
"description": "Transaction",
|
||||
"debit": "Payments Out",
|
||||
"credit": "Payments In",
|
||||
"balance": "Balance",
|
||||
"reference": None,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def _normalise_headers(raw_headers: list[str]) -> dict[str, str]:
|
||||
"""Return {normalised_key: original_header}."""
|
||||
return {h.strip().lower(): h.strip() for h in raw_headers if h}
|
||||
|
||||
|
||||
def detect_format(raw_headers: list[str]) -> CsvMapping:
|
||||
norm = _normalise_headers(raw_headers)
|
||||
norm_set = set(norm.keys())
|
||||
|
||||
for fmt in KNOWN_FORMATS:
|
||||
if fmt["detect"](norm_set):
|
||||
# Map logical names → actual header using case-insensitive lookup
|
||||
def resolve(col: str | None) -> str | None:
|
||||
if col is None:
|
||||
return None
|
||||
return norm.get(col.strip().lower(), col)
|
||||
|
||||
if "debit" in fmt:
|
||||
return CsvMapping(
|
||||
date=resolve(fmt["date"]) or fmt["date"],
|
||||
description=resolve(fmt["description"]) or fmt["description"],
|
||||
debit=resolve(fmt["debit"]),
|
||||
credit=resolve(fmt["credit"]),
|
||||
balance=resolve(fmt.get("balance")),
|
||||
reference=resolve(fmt.get("reference")),
|
||||
detected_format=fmt["name"],
|
||||
)
|
||||
else:
|
||||
return CsvMapping(
|
||||
date=resolve(fmt["date"]) or fmt["date"],
|
||||
description=resolve(fmt["description"]) or fmt["description"],
|
||||
amount=resolve(fmt["amount"]),
|
||||
balance=resolve(fmt.get("balance")),
|
||||
reference=resolve(fmt.get("reference")),
|
||||
detected_format=fmt["name"],
|
||||
)
|
||||
|
||||
# Generic fallback: guess by common column name patterns
|
||||
return _generic_mapping(norm)
|
||||
|
||||
|
||||
def _generic_mapping(norm: dict[str, str]) -> CsvMapping:
|
||||
def find(*candidates: str) -> str | None:
|
||||
for c in candidates:
|
||||
if c in norm:
|
||||
return norm[c]
|
||||
return None
|
||||
|
||||
date_col = find("date", "transaction date", "trans date", "value date", "posting date")
|
||||
desc_col = find("description", "narrative", "details", "memo", "payee", "merchant", "name", "counter party")
|
||||
amt_col = find("amount", "value", "net amount", "transaction amount")
|
||||
debit_col = find("debit", "debit amount", "payments out", "money out", "withdrawal")
|
||||
credit_col = find("credit", "credit amount", "payments in", "money in", "deposit")
|
||||
bal_col = find("balance", "running balance")
|
||||
ref_col = find("reference", "notes", "tags", "category")
|
||||
|
||||
if not date_col:
|
||||
date_col = list(norm.values())[0] if norm else "date"
|
||||
if not desc_col:
|
||||
desc_col = list(norm.values())[1] if len(norm) > 1 else "description"
|
||||
|
||||
if debit_col and credit_col:
|
||||
return CsvMapping(
|
||||
date=date_col,
|
||||
description=desc_col,
|
||||
debit=debit_col,
|
||||
credit=credit_col,
|
||||
balance=bal_col,
|
||||
reference=ref_col,
|
||||
detected_format=None,
|
||||
)
|
||||
|
||||
return CsvMapping(
|
||||
date=date_col,
|
||||
description=desc_col,
|
||||
amount=amt_col or (list(norm.values())[2] if len(norm) > 2 else "amount"),
|
||||
balance=bal_col,
|
||||
reference=ref_col,
|
||||
detected_format=None,
|
||||
)
|
||||
|
||||
|
||||
def parse_csv_content(content: bytes) -> tuple[list[str], list[dict]]:
|
||||
"""Decode and return (headers, rows)."""
|
||||
for enc in ("utf-8-sig", "utf-8", "latin-1"):
|
||||
try:
|
||||
text = content.decode(enc)
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
else:
|
||||
raise ValueError("Cannot decode file — try saving as UTF-8")
|
||||
|
||||
# Some bank exports (Lloyds, Barclays) include preamble lines before the header
|
||||
lines = text.splitlines()
|
||||
header_idx = 0
|
||||
for i, line in enumerate(lines):
|
||||
if "," in line and len(line.split(",")) >= 2:
|
||||
header_idx = i
|
||||
break
|
||||
|
||||
cleaned = "\n".join(lines[header_idx:])
|
||||
reader = csv.DictReader(io.StringIO(cleaned))
|
||||
headers = [h.strip() for h in (reader.fieldnames or []) if h and h.strip()]
|
||||
rows = []
|
||||
for row in reader:
|
||||
clean_row = {k.strip(): (v.strip() if v else "") for k, v in row.items() if k and k.strip()}
|
||||
if any(clean_row.values()):
|
||||
rows.append(clean_row)
|
||||
|
||||
return headers, rows
|
||||
300
backend/app/services/investment_service.py
Normal file
300
backend/app/services/investment_service.py
Normal file
|
|
@ -0,0 +1,300 @@
|
|||
import uuid
|
||||
from datetime import date, datetime, timezone
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models.asset import Asset
|
||||
from app.db.models.asset_price import AssetPrice
|
||||
from app.db.models.investment_holding import InvestmentHolding
|
||||
from app.db.models.investment_transaction import InvestmentTransaction
|
||||
from app.schemas.investment import (
|
||||
HoldingCreate,
|
||||
HoldingResponse,
|
||||
InvestmentTxnCreate,
|
||||
PerformanceMetrics,
|
||||
PortfolioSummary,
|
||||
)
|
||||
|
||||
|
||||
async def _get_asset(db: AsyncSession, asset_id: uuid.UUID) -> Asset | None:
|
||||
result = await db.execute(select(Asset).where(Asset.id == asset_id))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
def _holding_to_response(holding: InvestmentHolding, asset: Asset) -> HoldingResponse:
|
||||
cost_basis_total = holding.quantity * holding.avg_cost_basis
|
||||
current_price = asset.last_price
|
||||
current_value = holding.quantity * current_price if current_price else None
|
||||
unrealised_gain = (current_value - cost_basis_total) if current_value is not None else None
|
||||
unrealised_gain_pct = None
|
||||
if unrealised_gain is not None and cost_basis_total > 0:
|
||||
unrealised_gain_pct = (unrealised_gain / cost_basis_total * 100).quantize(Decimal("0.01"))
|
||||
|
||||
return HoldingResponse(
|
||||
id=holding.id,
|
||||
account_id=holding.account_id,
|
||||
asset_id=holding.asset_id,
|
||||
symbol=asset.symbol,
|
||||
asset_name=asset.name,
|
||||
asset_type=asset.type,
|
||||
quantity=holding.quantity,
|
||||
avg_cost_basis=holding.avg_cost_basis,
|
||||
current_price=current_price,
|
||||
current_value=current_value,
|
||||
cost_basis_total=cost_basis_total,
|
||||
unrealised_gain=unrealised_gain,
|
||||
unrealised_gain_pct=unrealised_gain_pct,
|
||||
currency=holding.currency,
|
||||
price_change_24h=asset.price_change_24h,
|
||||
)
|
||||
|
||||
|
||||
async def get_portfolio(db: AsyncSession, user_id: uuid.UUID) -> PortfolioSummary:
|
||||
result = await db.execute(
|
||||
select(InvestmentHolding).where(
|
||||
InvestmentHolding.user_id == user_id,
|
||||
InvestmentHolding.quantity > 0,
|
||||
)
|
||||
)
|
||||
holdings = result.scalars().all()
|
||||
|
||||
responses = []
|
||||
total_value = Decimal("0")
|
||||
total_cost = Decimal("0")
|
||||
|
||||
for h in holdings:
|
||||
asset = await _get_asset(db, h.asset_id)
|
||||
if not asset:
|
||||
continue
|
||||
r = _holding_to_response(h, asset)
|
||||
responses.append(r)
|
||||
total_cost += r.cost_basis_total
|
||||
if r.current_value is not None:
|
||||
total_value += r.current_value
|
||||
|
||||
total_gain = total_value - total_cost
|
||||
total_gain_pct = (total_gain / total_cost * 100).quantize(Decimal("0.01")) if total_cost > 0 else Decimal("0")
|
||||
|
||||
return PortfolioSummary(
|
||||
total_value=total_value,
|
||||
total_cost=total_cost,
|
||||
total_gain=total_gain,
|
||||
total_gain_pct=total_gain_pct,
|
||||
currency="GBP",
|
||||
holdings=responses,
|
||||
)
|
||||
|
||||
|
||||
async def get_holding(db: AsyncSession, user_id: uuid.UUID, holding_id: uuid.UUID) -> InvestmentHolding | None:
|
||||
result = await db.execute(
|
||||
select(InvestmentHolding).where(
|
||||
InvestmentHolding.id == holding_id,
|
||||
InvestmentHolding.user_id == user_id,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def create_holding(db: AsyncSession, user_id: uuid.UUID, data: HoldingCreate) -> InvestmentHolding:
|
||||
now = datetime.now(timezone.utc)
|
||||
# Check if holding already exists for this account+asset
|
||||
result = await db.execute(
|
||||
select(InvestmentHolding).where(
|
||||
InvestmentHolding.user_id == user_id,
|
||||
InvestmentHolding.account_id == data.account_id,
|
||||
InvestmentHolding.asset_id == data.asset_id,
|
||||
)
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
holding = InvestmentHolding(
|
||||
id=uuid.uuid4(),
|
||||
user_id=user_id,
|
||||
account_id=data.account_id,
|
||||
asset_id=data.asset_id,
|
||||
quantity=data.quantity,
|
||||
avg_cost_basis=data.avg_cost_basis,
|
||||
currency=data.currency,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
db.add(holding)
|
||||
await db.flush()
|
||||
await db.refresh(holding)
|
||||
return holding
|
||||
|
||||
|
||||
async def add_investment_transaction(
|
||||
db: AsyncSession, user_id: uuid.UUID, data: InvestmentTxnCreate
|
||||
) -> InvestmentTransaction:
|
||||
holding = await get_holding(db, user_id, data.holding_id)
|
||||
if not holding:
|
||||
raise ValueError("Holding not found")
|
||||
|
||||
total = data.quantity * data.price + data.fees
|
||||
|
||||
txn = InvestmentTransaction(
|
||||
id=uuid.uuid4(),
|
||||
user_id=user_id,
|
||||
holding_id=data.holding_id,
|
||||
type=data.type,
|
||||
quantity=data.quantity,
|
||||
price=data.price,
|
||||
fees=data.fees,
|
||||
total_amount=total,
|
||||
currency=data.currency,
|
||||
date=data.date,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db.add(txn)
|
||||
|
||||
# Update holding quantity and avg cost basis
|
||||
if data.type == "buy" or data.type == "transfer_in":
|
||||
new_qty = holding.quantity + data.quantity
|
||||
if new_qty > 0:
|
||||
holding.avg_cost_basis = (
|
||||
(holding.quantity * holding.avg_cost_basis + data.quantity * data.price)
|
||||
/ new_qty
|
||||
)
|
||||
holding.quantity = new_qty
|
||||
elif data.type == "sell" or data.type == "transfer_out":
|
||||
holding.quantity = max(Decimal("0"), holding.quantity - data.quantity)
|
||||
elif data.type == "split":
|
||||
if data.price > 0:
|
||||
holding.quantity = holding.quantity * data.quantity
|
||||
holding.avg_cost_basis = holding.avg_cost_basis / data.quantity
|
||||
# dividend and fee don't affect quantity/cost basis
|
||||
|
||||
holding.updated_at = datetime.now(timezone.utc)
|
||||
await db.flush()
|
||||
await db.refresh(txn)
|
||||
return txn
|
||||
|
||||
|
||||
async def list_investment_transactions(
|
||||
db: AsyncSession, user_id: uuid.UUID, holding_id: uuid.UUID
|
||||
) -> list[InvestmentTransaction]:
|
||||
result = await db.execute(
|
||||
select(InvestmentTransaction)
|
||||
.where(
|
||||
InvestmentTransaction.user_id == user_id,
|
||||
InvestmentTransaction.holding_id == holding_id,
|
||||
)
|
||||
.order_by(InvestmentTransaction.date.desc())
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
async def get_performance(db: AsyncSession, user_id: uuid.UUID) -> PerformanceMetrics:
|
||||
portfolio = await get_portfolio(db, user_id)
|
||||
total_return = portfolio.total_gain
|
||||
total_return_pct = portfolio.total_gain_pct
|
||||
return PerformanceMetrics(
|
||||
twrr=None, # full TWRR requires snapshot history — placeholder
|
||||
total_return=total_return,
|
||||
total_return_pct=total_return_pct,
|
||||
currency="GBP",
|
||||
)
|
||||
|
||||
|
||||
async def get_or_create_asset(
|
||||
db: AsyncSession, symbol: str, name: str, asset_type: str,
|
||||
currency: str, data_source: str, data_source_id: str | None,
|
||||
exchange: str | None = None,
|
||||
) -> Asset:
|
||||
result = await db.execute(
|
||||
select(Asset).where(Asset.symbol == symbol.upper(), Asset.data_source == data_source)
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
asset = Asset(
|
||||
id=uuid.uuid4(),
|
||||
symbol=symbol.upper(),
|
||||
name=name,
|
||||
type=asset_type,
|
||||
currency=currency,
|
||||
exchange=exchange,
|
||||
data_source=data_source,
|
||||
data_source_id=data_source_id,
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
db.add(asset)
|
||||
await db.flush()
|
||||
await db.refresh(asset)
|
||||
return asset
|
||||
|
||||
|
||||
async def update_asset_price(
|
||||
db: AsyncSession, asset: Asset, price: Decimal, change_24h: Decimal | None
|
||||
) -> None:
|
||||
asset.last_price = price
|
||||
asset.price_change_24h = change_24h
|
||||
asset.last_price_at = datetime.now(timezone.utc)
|
||||
asset.updated_at = datetime.now(timezone.utc)
|
||||
await db.flush()
|
||||
|
||||
|
||||
async def upsert_price_history(db: AsyncSession, asset_id: uuid.UUID, rows: list[dict]) -> int:
|
||||
count = 0
|
||||
for row in rows:
|
||||
result = await db.execute(
|
||||
select(AssetPrice).where(AssetPrice.asset_id == asset_id, AssetPrice.date == row["date"])
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
if existing:
|
||||
existing.open = row["open"]
|
||||
existing.high = row["high"]
|
||||
existing.low = row["low"]
|
||||
existing.close = row["close"]
|
||||
existing.volume = row["volume"]
|
||||
else:
|
||||
db.add(AssetPrice(
|
||||
id=uuid.uuid4(),
|
||||
asset_id=asset_id,
|
||||
date=row["date"],
|
||||
open=row.get("open"),
|
||||
high=row.get("high"),
|
||||
low=row.get("low"),
|
||||
close=row["close"],
|
||||
volume=row.get("volume"),
|
||||
created_at=datetime.now(timezone.utc),
|
||||
))
|
||||
count += 1
|
||||
await db.flush()
|
||||
return count
|
||||
|
||||
|
||||
async def get_price_history(
|
||||
db: AsyncSession, asset_id: uuid.UUID, days: int = 365
|
||||
) -> list[AssetPrice]:
|
||||
from datetime import timedelta
|
||||
cutoff = date.today() - timedelta(days=days)
|
||||
result = await db.execute(
|
||||
select(AssetPrice)
|
||||
.where(AssetPrice.asset_id == asset_id, AssetPrice.date >= cutoff)
|
||||
.order_by(AssetPrice.date.asc())
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
async def search_assets(db: AsyncSession, query: str) -> list[Asset]:
|
||||
from sqlalchemy import or_, func
|
||||
q = query.strip().upper()
|
||||
result = await db.execute(
|
||||
select(Asset).where(
|
||||
or_(
|
||||
func.upper(Asset.symbol).contains(q),
|
||||
func.upper(Asset.name).contains(q),
|
||||
)
|
||||
).limit(10)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
116
backend/app/services/price_feed_service.py
Normal file
116
backend/app/services/price_feed_service.py
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
"""
|
||||
Live price fetching: yfinance for stocks/ETFs, CoinGecko for crypto.
|
||||
Falls back gracefully — never raises, always returns None on failure.
|
||||
"""
|
||||
import asyncio
|
||||
from datetime import date, datetime, timezone, timedelta
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
async def _run_sync(fn, *args):
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, fn, *args)
|
||||
|
||||
|
||||
def _fetch_yahoo(symbol: str) -> dict | None:
|
||||
try:
|
||||
import yfinance as yf
|
||||
ticker = yf.Ticker(symbol)
|
||||
info = ticker.fast_info
|
||||
price = getattr(info, "last_price", None) or getattr(info, "regularMarketPrice", None)
|
||||
prev = getattr(info, "previous_close", None)
|
||||
if price is None:
|
||||
return None
|
||||
change_24h = None
|
||||
if prev and prev > 0:
|
||||
change_24h = round((price - prev) / prev * 100, 4)
|
||||
return {
|
||||
"price": Decimal(str(round(price, 8))),
|
||||
"change_24h": Decimal(str(change_24h)) if change_24h is not None else None,
|
||||
"currency": (getattr(info, "currency", None) or "USD").upper(),
|
||||
"name": getattr(info, "long_name", None) or symbol,
|
||||
"exchange": getattr(info, "exchange", None),
|
||||
}
|
||||
except Exception as exc:
|
||||
logger.warning("yahoo_fetch_failed", symbol=symbol, error=str(exc))
|
||||
return None
|
||||
|
||||
|
||||
def _fetch_coingecko(coin_id: str) -> dict | None:
|
||||
try:
|
||||
import requests
|
||||
r = requests.get(
|
||||
f"https://api.coingecko.com/api/v3/simple/price",
|
||||
params={"ids": coin_id, "vs_currencies": "usd,gbp", "include_24hr_change": "true"},
|
||||
timeout=10,
|
||||
)
|
||||
r.raise_for_status()
|
||||
data = r.json().get(coin_id, {})
|
||||
if not data:
|
||||
return None
|
||||
return {
|
||||
"price": Decimal(str(data.get("gbp", data.get("usd", 0)))),
|
||||
"change_24h": Decimal(str(round(data.get("gbp_24h_change", 0), 4))),
|
||||
"currency": "GBP",
|
||||
"name": coin_id,
|
||||
}
|
||||
except Exception as exc:
|
||||
logger.warning("coingecko_fetch_failed", coin_id=coin_id, error=str(exc))
|
||||
return None
|
||||
|
||||
|
||||
def _fetch_yahoo_history(symbol: str, days: int = 365) -> list[dict]:
|
||||
try:
|
||||
import yfinance as yf
|
||||
ticker = yf.Ticker(symbol)
|
||||
hist = ticker.history(period=f"{days}d", interval="1d")
|
||||
rows = []
|
||||
for ts, row in hist.iterrows():
|
||||
rows.append({
|
||||
"date": ts.date(),
|
||||
"open": Decimal(str(round(float(row["Open"]), 8))),
|
||||
"high": Decimal(str(round(float(row["High"]), 8))),
|
||||
"low": Decimal(str(round(float(row["Low"]), 8))),
|
||||
"close": Decimal(str(round(float(row["Close"]), 8))),
|
||||
"volume": Decimal(str(int(row.get("Volume", 0) or 0))),
|
||||
})
|
||||
return rows
|
||||
except Exception as exc:
|
||||
logger.warning("yahoo_history_failed", symbol=symbol, error=str(exc))
|
||||
return []
|
||||
|
||||
|
||||
async def fetch_price(symbol: str, data_source: str, data_source_id: str | None) -> dict | None:
|
||||
if data_source == "coingecko":
|
||||
return await _run_sync(_fetch_coingecko, data_source_id or symbol.lower())
|
||||
return await _run_sync(_fetch_yahoo, symbol)
|
||||
|
||||
|
||||
async def fetch_history(symbol: str, days: int = 365) -> list[dict]:
|
||||
return await _run_sync(_fetch_yahoo_history, symbol, days)
|
||||
|
||||
|
||||
def search_yahoo(query: str) -> list[dict]:
|
||||
try:
|
||||
import yfinance as yf
|
||||
ticker = yf.Ticker(query)
|
||||
info = ticker.fast_info
|
||||
price = getattr(info, "last_price", None)
|
||||
if price:
|
||||
return [{
|
||||
"symbol": query.upper(),
|
||||
"name": getattr(info, "long_name", None) or query.upper(),
|
||||
"type": "stock",
|
||||
"currency": (getattr(info, "currency", None) or "USD").upper(),
|
||||
"exchange": getattr(info, "exchange", None),
|
||||
"data_source": "yahoo_finance",
|
||||
"data_source_id": None,
|
||||
}]
|
||||
except Exception:
|
||||
pass
|
||||
return []
|
||||
356
backend/app/services/report_service.py
Normal file
356
backend/app/services/report_service.py
Normal file
|
|
@ -0,0 +1,356 @@
|
|||
import uuid
|
||||
from datetime import date, datetime, timezone
|
||||
from decimal import Decimal
|
||||
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from sqlalchemy import and_, func, select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models.account import Account
|
||||
from app.db.models.budget import Budget
|
||||
from app.db.models.category import Category
|
||||
from app.db.models.net_worth_snapshot import NetWorthSnapshot
|
||||
from app.db.models.transaction import Transaction
|
||||
from app.schemas.report import (
|
||||
BudgetVsActualItem,
|
||||
BudgetVsActualReport,
|
||||
CashFlowPoint,
|
||||
CashFlowReport,
|
||||
CategoryBreakdownItem,
|
||||
CategoryBreakdownReport,
|
||||
IncomeExpensePoint,
|
||||
IncomeExpenseReport,
|
||||
NetWorthPoint,
|
||||
NetWorthReport,
|
||||
SpendingTrendPoint,
|
||||
SpendingTrendsReport,
|
||||
)
|
||||
|
||||
LIABILITY_TYPES = {"credit_card", "loan", "mortgage"}
|
||||
|
||||
|
||||
async def _current_net_worth(db: AsyncSession, user_id: uuid.UUID) -> tuple[Decimal, Decimal]:
|
||||
result = await db.execute(
|
||||
select(Account).where(
|
||||
Account.user_id == user_id,
|
||||
Account.include_in_net_worth == True, # noqa: E712
|
||||
Account.is_active == True, # noqa: E712
|
||||
Account.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
accounts = result.scalars().all()
|
||||
assets = Decimal("0")
|
||||
liabilities = Decimal("0")
|
||||
for acc in accounts:
|
||||
bal = acc.current_balance or Decimal("0")
|
||||
if acc.type in LIABILITY_TYPES:
|
||||
liabilities += bal
|
||||
else:
|
||||
assets += bal
|
||||
return assets, liabilities
|
||||
|
||||
|
||||
async def get_net_worth_report(
|
||||
db: AsyncSession, user_id: uuid.UUID, base_currency: str, months: int = 12
|
||||
) -> NetWorthReport:
|
||||
cutoff = date.today() - relativedelta(months=months)
|
||||
result = await db.execute(
|
||||
select(NetWorthSnapshot)
|
||||
.where(NetWorthSnapshot.user_id == user_id, NetWorthSnapshot.date >= cutoff)
|
||||
.order_by(NetWorthSnapshot.date.asc())
|
||||
)
|
||||
snapshots = result.scalars().all()
|
||||
|
||||
points = [
|
||||
NetWorthPoint(
|
||||
date=s.date,
|
||||
total_assets=s.total_assets,
|
||||
total_liabilities=s.total_liabilities,
|
||||
net_worth=s.net_worth,
|
||||
base_currency=s.base_currency,
|
||||
)
|
||||
for s in snapshots
|
||||
]
|
||||
|
||||
assets, liabilities = await _current_net_worth(db, user_id)
|
||||
current_nw = assets - liabilities
|
||||
|
||||
change_30d = Decimal("0")
|
||||
change_30d_pct = Decimal("0")
|
||||
if points:
|
||||
past_nw = points[0].net_worth
|
||||
change_30d = current_nw - past_nw
|
||||
if past_nw != 0:
|
||||
change_30d_pct = (change_30d / abs(past_nw) * 100).quantize(Decimal("0.01"))
|
||||
|
||||
return NetWorthReport(
|
||||
points=points,
|
||||
current_net_worth=current_nw,
|
||||
change_30d=change_30d,
|
||||
change_30d_pct=change_30d_pct,
|
||||
base_currency=base_currency,
|
||||
)
|
||||
|
||||
|
||||
async def get_income_expense_report(
|
||||
db: AsyncSession, user_id: uuid.UUID, months: int = 12
|
||||
) -> IncomeExpenseReport:
|
||||
cutoff = (date.today().replace(day=1) - relativedelta(months=months - 1))
|
||||
result = await db.execute(
|
||||
text("""
|
||||
SELECT
|
||||
TO_CHAR(date, 'YYYY-MM') AS month,
|
||||
SUM(CASE WHEN type = 'income' THEN amount ELSE 0 END) AS income,
|
||||
SUM(CASE WHEN type = 'expense' THEN ABS(amount) ELSE 0 END) AS expenses
|
||||
FROM transactions
|
||||
WHERE user_id = CAST(:uid AS uuid)
|
||||
AND status != 'void'
|
||||
AND deleted_at IS NULL
|
||||
AND date >= :cutoff
|
||||
GROUP BY TO_CHAR(date, 'YYYY-MM')
|
||||
ORDER BY month ASC
|
||||
""").bindparams(uid=str(user_id), cutoff=cutoff)
|
||||
)
|
||||
rows = result.fetchall()
|
||||
|
||||
points = []
|
||||
total_income = Decimal("0")
|
||||
total_expenses = Decimal("0")
|
||||
for row in rows:
|
||||
inc = Decimal(str(row.income or 0))
|
||||
exp = Decimal(str(row.expenses or 0))
|
||||
points.append(IncomeExpensePoint(month=row.month, income=inc, expenses=exp, net=inc - exp))
|
||||
total_income += inc
|
||||
total_expenses += exp
|
||||
|
||||
n = len(points) or 1
|
||||
return IncomeExpenseReport(
|
||||
points=points,
|
||||
total_income=total_income,
|
||||
total_expenses=total_expenses,
|
||||
avg_monthly_income=(total_income / n).quantize(Decimal("0.01")),
|
||||
avg_monthly_expenses=(total_expenses / n).quantize(Decimal("0.01")),
|
||||
currency="GBP",
|
||||
)
|
||||
|
||||
|
||||
async def get_cash_flow_report(
|
||||
db: AsyncSession, user_id: uuid.UUID, date_from: date, date_to: date
|
||||
) -> CashFlowReport:
|
||||
result = await db.execute(
|
||||
text("""
|
||||
SELECT
|
||||
date,
|
||||
SUM(CASE WHEN amount > 0 THEN amount ELSE 0 END) AS inflow,
|
||||
SUM(CASE WHEN amount < 0 THEN ABS(amount) ELSE 0 END) AS outflow
|
||||
FROM transactions
|
||||
WHERE user_id = CAST(:uid AS uuid)
|
||||
AND status != 'void'
|
||||
AND deleted_at IS NULL
|
||||
AND date BETWEEN :df AND :dt
|
||||
AND type IN ('income', 'expense')
|
||||
GROUP BY date
|
||||
ORDER BY date ASC
|
||||
""").bindparams(uid=str(user_id), df=date_from, dt=date_to)
|
||||
)
|
||||
rows = result.fetchall()
|
||||
|
||||
points = []
|
||||
running = Decimal("0")
|
||||
total_inflow = Decimal("0")
|
||||
total_outflow = Decimal("0")
|
||||
for row in rows:
|
||||
inflow = Decimal(str(row.inflow or 0))
|
||||
outflow = Decimal(str(row.outflow or 0))
|
||||
running += inflow - outflow
|
||||
total_inflow += inflow
|
||||
total_outflow += outflow
|
||||
points.append(
|
||||
CashFlowPoint(
|
||||
date=row.date,
|
||||
inflow=inflow,
|
||||
outflow=outflow,
|
||||
net=inflow - outflow,
|
||||
running_balance=running,
|
||||
)
|
||||
)
|
||||
|
||||
return CashFlowReport(
|
||||
points=points,
|
||||
total_inflow=total_inflow,
|
||||
total_outflow=total_outflow,
|
||||
currency="GBP",
|
||||
)
|
||||
|
||||
|
||||
async def get_category_breakdown(
|
||||
db: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
date_from: date,
|
||||
date_to: date,
|
||||
txn_type: str = "expense",
|
||||
) -> CategoryBreakdownReport:
|
||||
result = await db.execute(
|
||||
select(
|
||||
Transaction.category_id,
|
||||
func.sum(func.abs(Transaction.amount)).label("total"),
|
||||
func.count(Transaction.id).label("cnt"),
|
||||
)
|
||||
.where(
|
||||
Transaction.user_id == user_id,
|
||||
Transaction.type == txn_type,
|
||||
Transaction.status != "void",
|
||||
Transaction.date >= date_from,
|
||||
Transaction.date <= date_to,
|
||||
Transaction.deleted_at.is_(None),
|
||||
)
|
||||
.group_by(Transaction.category_id)
|
||||
.order_by(func.sum(func.abs(Transaction.amount)).desc())
|
||||
)
|
||||
rows = result.fetchall()
|
||||
|
||||
grand_total = Decimal("0")
|
||||
raw = []
|
||||
for row in rows:
|
||||
amt = Decimal(str(row.total or 0))
|
||||
grand_total += amt
|
||||
if row.category_id:
|
||||
cat_result = await db.execute(select(Category).where(Category.id == row.category_id))
|
||||
category = cat_result.scalar_one_or_none()
|
||||
cat_name = category.name if category else "Uncategorised"
|
||||
else:
|
||||
cat_name = "Uncategorised"
|
||||
raw.append((row.category_id, cat_name, amt, row.cnt))
|
||||
|
||||
items = [
|
||||
CategoryBreakdownItem(
|
||||
category_id=str(cat_id) if cat_id else None,
|
||||
category_name=name,
|
||||
amount=amt,
|
||||
percent=(amt / grand_total * 100).quantize(Decimal("0.01")) if grand_total > 0 else Decimal("0"),
|
||||
transaction_count=cnt,
|
||||
)
|
||||
for cat_id, name, amt, cnt in raw
|
||||
]
|
||||
|
||||
return CategoryBreakdownReport(
|
||||
items=items,
|
||||
total=grand_total,
|
||||
currency="GBP",
|
||||
date_from=date_from,
|
||||
date_to=date_to,
|
||||
)
|
||||
|
||||
|
||||
async def get_budget_vs_actual(db: AsyncSession, user_id: uuid.UUID) -> BudgetVsActualReport:
|
||||
from app.services.budget_service import list_budgets, _period_bounds
|
||||
today = date.today()
|
||||
budgets = await list_budgets(db, user_id, active_only=True)
|
||||
|
||||
items = []
|
||||
total_budgeted = Decimal("0")
|
||||
total_actual = Decimal("0")
|
||||
|
||||
for budget in budgets:
|
||||
period_start, period_end = _period_bounds(budget.period, today)
|
||||
cat_result = await db.execute(select(Category).where(Category.id == budget.category_id))
|
||||
category = cat_result.scalar_one_or_none()
|
||||
cat_name = category.name if category else "Unknown"
|
||||
|
||||
spent_result = await db.execute(
|
||||
select(func.coalesce(func.sum(func.abs(Transaction.amount)), Decimal("0")))
|
||||
.where(
|
||||
and_(
|
||||
Transaction.user_id == user_id,
|
||||
Transaction.category_id == budget.category_id,
|
||||
Transaction.type == "expense",
|
||||
Transaction.status != "void",
|
||||
Transaction.date >= period_start,
|
||||
Transaction.date <= period_end,
|
||||
Transaction.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
)
|
||||
actual = Decimal(str(spent_result.scalar() or 0))
|
||||
variance = budget.amount - actual
|
||||
pct = (actual / budget.amount * 100).quantize(Decimal("0.01")) if budget.amount > 0 else Decimal("0")
|
||||
|
||||
items.append(
|
||||
BudgetVsActualItem(
|
||||
budget_id=str(budget.id),
|
||||
budget_name=budget.name,
|
||||
category_name=cat_name,
|
||||
budgeted=budget.amount,
|
||||
actual=actual,
|
||||
variance=variance,
|
||||
percent_used=pct,
|
||||
)
|
||||
)
|
||||
total_budgeted += budget.amount
|
||||
total_actual += actual
|
||||
|
||||
return BudgetVsActualReport(
|
||||
items=items,
|
||||
total_budgeted=total_budgeted,
|
||||
total_actual=total_actual,
|
||||
currency="GBP",
|
||||
)
|
||||
|
||||
|
||||
async def get_spending_trends(
|
||||
db: AsyncSession, user_id: uuid.UUID, months: int = 6
|
||||
) -> SpendingTrendsReport:
|
||||
cutoff = (date.today().replace(day=1) - relativedelta(months=months - 1))
|
||||
result = await db.execute(
|
||||
text("""
|
||||
SELECT
|
||||
TO_CHAR(t.date, 'YYYY-MM') AS month,
|
||||
COALESCE(c.name, 'Uncategorised') AS category_name,
|
||||
SUM(ABS(t.amount)) AS amount
|
||||
FROM transactions t
|
||||
LEFT JOIN categories c ON c.id = t.category_id
|
||||
WHERE t.user_id = CAST(:uid AS uuid)
|
||||
AND t.type = 'expense'
|
||||
AND t.status != 'void'
|
||||
AND t.deleted_at IS NULL
|
||||
AND t.date >= :cutoff
|
||||
GROUP BY TO_CHAR(t.date, 'YYYY-MM'), c.name
|
||||
ORDER BY month ASC, amount DESC
|
||||
""").bindparams(uid=str(user_id), cutoff=cutoff)
|
||||
)
|
||||
rows = result.fetchall()
|
||||
|
||||
points = [
|
||||
SpendingTrendPoint(month=row.month, category_name=row.category_name, amount=Decimal(str(row.amount or 0)))
|
||||
for row in rows
|
||||
]
|
||||
categories = list(dict.fromkeys(p.category_name for p in points))
|
||||
|
||||
return SpendingTrendsReport(points=points, categories=categories, currency="GBP")
|
||||
|
||||
|
||||
async def take_net_worth_snapshot(db: AsyncSession, user_id: uuid.UUID, base_currency: str) -> None:
|
||||
today = date.today()
|
||||
existing = await db.execute(
|
||||
select(NetWorthSnapshot).where(
|
||||
NetWorthSnapshot.user_id == user_id,
|
||||
NetWorthSnapshot.date == today,
|
||||
)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
return
|
||||
|
||||
assets, liabilities = await _current_net_worth(db, user_id)
|
||||
snapshot = NetWorthSnapshot(
|
||||
id=uuid.uuid4(),
|
||||
user_id=user_id,
|
||||
date=today,
|
||||
total_assets=assets,
|
||||
total_liabilities=liabilities,
|
||||
net_worth=assets - liabilities,
|
||||
base_currency=base_currency,
|
||||
breakdown={},
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db.add(snapshot)
|
||||
await db.flush()
|
||||
308
backend/app/services/transaction_service.py
Normal file
308
backend/app/services/transaction_service.py
Normal file
|
|
@ -0,0 +1,308 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlalchemy import and_, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.security import decrypt_field, encrypt_field
|
||||
from app.db.models.transaction import Transaction
|
||||
from app.schemas.transaction import TransactionCreate, TransactionFilter, TransactionUpdate
|
||||
from app.services.account_service import recalculate_balance
|
||||
|
||||
|
||||
class TransactionError(Exception):
|
||||
def __init__(self, detail: str, status_code: int = 400):
|
||||
self.detail = detail
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
def _enc(v: str | None) -> bytes | None:
|
||||
return encrypt_field(v) if v else None
|
||||
|
||||
|
||||
def _dec(v: bytes | None) -> str | None:
|
||||
return decrypt_field(v) if v else None
|
||||
|
||||
|
||||
def _to_response(t: Transaction) -> dict:
|
||||
return {
|
||||
"id": t.id,
|
||||
"account_id": t.account_id,
|
||||
"transfer_account_id": t.transfer_account_id,
|
||||
"category_id": t.category_id,
|
||||
"type": t.type,
|
||||
"status": t.status,
|
||||
"amount": t.amount,
|
||||
"amount_base": t.amount_base,
|
||||
"currency": t.currency,
|
||||
"base_currency": t.base_currency,
|
||||
"exchange_rate": t.exchange_rate,
|
||||
"date": t.date,
|
||||
"description": _dec(t.description_enc) or "",
|
||||
"merchant": _dec(t.merchant_enc),
|
||||
"notes": _dec(t.notes_enc),
|
||||
"tags": t.tags or [],
|
||||
"is_recurring": t.is_recurring,
|
||||
"created_at": t.created_at,
|
||||
"updated_at": t.updated_at,
|
||||
}
|
||||
|
||||
|
||||
async def create_transaction(
|
||||
db: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
data: TransactionCreate,
|
||||
base_currency: str,
|
||||
) -> dict:
|
||||
now = datetime.now(timezone.utc)
|
||||
amount = data.amount
|
||||
|
||||
# For transfers, create mirrored entry on destination account
|
||||
txn = Transaction(
|
||||
user_id=user_id,
|
||||
account_id=data.account_id,
|
||||
transfer_account_id=data.transfer_account_id,
|
||||
category_id=data.category_id,
|
||||
type=data.type,
|
||||
status=data.status,
|
||||
amount=amount,
|
||||
amount_base=amount, # Phase 3: convert via FX rate
|
||||
currency=data.currency,
|
||||
base_currency=base_currency,
|
||||
exchange_rate=Decimal("1") if data.currency == base_currency else None,
|
||||
date=data.date,
|
||||
description_enc=encrypt_field(data.description),
|
||||
merchant_enc=_enc(data.merchant),
|
||||
notes_enc=_enc(data.notes),
|
||||
tags=data.tags,
|
||||
is_recurring=data.is_recurring,
|
||||
recurring_rule=data.recurring_rule,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
db.add(txn)
|
||||
await db.flush()
|
||||
|
||||
# If transfer, create the counter-entry on the destination account
|
||||
if data.type == "transfer" and data.transfer_account_id:
|
||||
counter = Transaction(
|
||||
user_id=user_id,
|
||||
account_id=data.transfer_account_id,
|
||||
transfer_account_id=data.account_id,
|
||||
category_id=data.category_id,
|
||||
type="transfer",
|
||||
status=data.status,
|
||||
amount=-amount, # opposite sign
|
||||
amount_base=-amount,
|
||||
currency=data.currency,
|
||||
base_currency=base_currency,
|
||||
exchange_rate=Decimal("1") if data.currency == base_currency else None,
|
||||
date=data.date,
|
||||
description_enc=encrypt_field(data.description),
|
||||
merchant_enc=_enc(data.merchant),
|
||||
notes_enc=_enc(data.notes),
|
||||
tags=data.tags,
|
||||
is_recurring=False,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
db.add(counter)
|
||||
await db.flush()
|
||||
await recalculate_balance(db, data.transfer_account_id)
|
||||
|
||||
await recalculate_balance(db, data.account_id)
|
||||
return _to_response(txn)
|
||||
|
||||
|
||||
async def list_transactions(
|
||||
db: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
filters: TransactionFilter,
|
||||
) -> dict:
|
||||
conditions = [
|
||||
Transaction.user_id == user_id,
|
||||
Transaction.deleted_at.is_(None),
|
||||
]
|
||||
|
||||
if filters.account_id:
|
||||
conditions.append(Transaction.account_id == filters.account_id)
|
||||
if filters.category_id:
|
||||
conditions.append(Transaction.category_id == filters.category_id)
|
||||
if filters.type:
|
||||
conditions.append(Transaction.type == filters.type)
|
||||
if filters.status:
|
||||
conditions.append(Transaction.status == filters.status)
|
||||
if filters.date_from:
|
||||
conditions.append(Transaction.date >= filters.date_from)
|
||||
if filters.date_to:
|
||||
conditions.append(Transaction.date <= filters.date_to)
|
||||
if filters.min_amount is not None:
|
||||
conditions.append(Transaction.amount >= filters.min_amount)
|
||||
if filters.max_amount is not None:
|
||||
conditions.append(Transaction.amount <= filters.max_amount)
|
||||
|
||||
query = select(Transaction).where(and_(*conditions)).order_by(Transaction.date.desc(), Transaction.created_at.desc())
|
||||
|
||||
# Count total
|
||||
from sqlalchemy import func
|
||||
count_result = await db.execute(select(func.count()).select_from(query.subquery()))
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Paginate
|
||||
offset = (filters.page - 1) * filters.page_size
|
||||
query = query.offset(offset).limit(filters.page_size)
|
||||
result = await db.execute(query)
|
||||
items = [_to_response(t) for t in result.scalars()]
|
||||
|
||||
# Filter by search (post-decrypt — Phase 3 will add FTS)
|
||||
if filters.search:
|
||||
term = filters.search.lower()
|
||||
items = [
|
||||
t for t in items
|
||||
if term in t["description"].lower()
|
||||
or (t["merchant"] and term in t["merchant"].lower())
|
||||
]
|
||||
|
||||
return {
|
||||
"items": items,
|
||||
"total": total,
|
||||
"page": filters.page,
|
||||
"page_size": filters.page_size,
|
||||
"pages": max(1, -(-total // filters.page_size)),
|
||||
}
|
||||
|
||||
|
||||
async def get_transaction(db: AsyncSession, txn_id: uuid.UUID, user_id: uuid.UUID) -> Transaction:
|
||||
result = await db.execute(
|
||||
select(Transaction).where(
|
||||
Transaction.id == txn_id,
|
||||
Transaction.user_id == user_id,
|
||||
Transaction.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
txn = result.scalar_one_or_none()
|
||||
if not txn:
|
||||
raise TransactionError("Transaction not found", status_code=404)
|
||||
return txn
|
||||
|
||||
|
||||
async def update_transaction(
|
||||
db: AsyncSession,
|
||||
txn_id: uuid.UUID,
|
||||
user_id: uuid.UUID,
|
||||
data: TransactionUpdate,
|
||||
base_currency: str,
|
||||
) -> dict:
|
||||
txn = await get_transaction(db, txn_id, user_id)
|
||||
now = datetime.now(timezone.utc)
|
||||
old_account_id = txn.account_id
|
||||
|
||||
if data.category_id is not None:
|
||||
txn.category_id = data.category_id
|
||||
if data.status is not None:
|
||||
txn.status = data.status
|
||||
if data.amount is not None:
|
||||
txn.amount = data.amount
|
||||
txn.amount_base = data.amount
|
||||
if data.date is not None:
|
||||
txn.date = data.date
|
||||
if data.description is not None:
|
||||
txn.description_enc = encrypt_field(data.description)
|
||||
if data.merchant is not None:
|
||||
txn.merchant_enc = _enc(data.merchant)
|
||||
if data.notes is not None:
|
||||
txn.notes_enc = _enc(data.notes)
|
||||
if data.tags is not None:
|
||||
txn.tags = data.tags
|
||||
|
||||
txn.updated_at = now
|
||||
await db.flush()
|
||||
await recalculate_balance(db, old_account_id)
|
||||
return _to_response(txn)
|
||||
|
||||
|
||||
async def delete_transaction(db: AsyncSession, txn_id: uuid.UUID, user_id: uuid.UUID) -> None:
|
||||
txn = await get_transaction(db, txn_id, user_id)
|
||||
account_id = txn.account_id
|
||||
txn.deleted_at = datetime.now(timezone.utc)
|
||||
txn.updated_at = datetime.now(timezone.utc)
|
||||
await db.flush()
|
||||
await recalculate_balance(db, account_id)
|
||||
|
||||
|
||||
async def import_csv(
|
||||
db: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
account_id: uuid.UUID,
|
||||
rows: list[dict],
|
||||
base_currency: str,
|
||||
) -> dict:
|
||||
"""
|
||||
Import transactions from parsed CSV rows.
|
||||
Each row must have: date, description, amount
|
||||
Optional: merchant, notes, category_name
|
||||
Returns counts of imported vs skipped (duplicates).
|
||||
"""
|
||||
imported = 0
|
||||
skipped = 0
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
for row in rows:
|
||||
# Build dedup hash from date + description + amount
|
||||
raw = f"{row['date']}|{row['description']}|{row['amount']}"
|
||||
import_hash = hashlib.sha256(raw.encode()).hexdigest()
|
||||
|
||||
# Check duplicate
|
||||
exists = await db.scalar(
|
||||
select(Transaction.id).where(
|
||||
Transaction.user_id == user_id,
|
||||
Transaction.import_hash == import_hash,
|
||||
)
|
||||
)
|
||||
if exists:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
amount = Decimal(str(row["amount"]))
|
||||
from datetime import date as date_type
|
||||
import dateutil.parser
|
||||
txn_date = dateutil.parser.parse(str(row["date"])).date()
|
||||
except Exception:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
txn_type = "income" if amount > 0 else "expense"
|
||||
|
||||
txn = Transaction(
|
||||
user_id=user_id,
|
||||
account_id=account_id,
|
||||
type=txn_type,
|
||||
status="cleared",
|
||||
amount=amount,
|
||||
amount_base=amount,
|
||||
currency=row.get("currency", base_currency),
|
||||
base_currency=base_currency,
|
||||
exchange_rate=Decimal("1"),
|
||||
date=txn_date,
|
||||
description_enc=encrypt_field(str(row.get("description", ""))),
|
||||
merchant_enc=_enc(row.get("merchant")),
|
||||
notes_enc=_enc(row.get("notes")),
|
||||
tags=[],
|
||||
is_recurring=False,
|
||||
import_hash=import_hash,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
db.add(txn)
|
||||
imported += 1
|
||||
|
||||
await db.flush()
|
||||
if imported > 0:
|
||||
await recalculate_balance(db, account_id)
|
||||
|
||||
return {"imported": imported, "skipped": skipped}
|
||||
0
backend/app/workers/__init__.py
Normal file
0
backend/app/workers/__init__.py
Normal file
74
backend/app/workers/fx_sync.py
Normal file
74
backend/app/workers/fx_sync.py
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
import structlog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
PAIRS = [
|
||||
("GBP", "USD"), ("GBP", "EUR"), ("GBP", "JPY"), ("GBP", "CAD"),
|
||||
("GBP", "AUD"), ("GBP", "CHF"), ("USD", "GBP"), ("EUR", "GBP"),
|
||||
]
|
||||
|
||||
|
||||
async def fx_sync_job() -> None:
|
||||
from app.dependencies import get_session_factory
|
||||
session_factory = get_session_factory()
|
||||
if not session_factory:
|
||||
return
|
||||
|
||||
try:
|
||||
import requests
|
||||
r = requests.get(
|
||||
"https://api.exchangerate-api.com/v4/latest/GBP",
|
||||
timeout=10,
|
||||
)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
rates = data.get("rates", {})
|
||||
except Exception as exc:
|
||||
logger.error("fx_fetch_failed", error=str(exc))
|
||||
return
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
import uuid as _uuid
|
||||
from sqlalchemy import select
|
||||
from app.db.models.currency import ExchangeRate # type: ignore[attr-defined]
|
||||
|
||||
async with session_factory() as db:
|
||||
try:
|
||||
now = datetime.now(timezone.utc)
|
||||
for base, quote in PAIRS:
|
||||
if base == "GBP":
|
||||
rate_val = rates.get(quote)
|
||||
else:
|
||||
gbp_to_base = rates.get(base)
|
||||
if not gbp_to_base or gbp_to_base == 0:
|
||||
continue
|
||||
rate_val = 1 / gbp_to_base
|
||||
|
||||
if not rate_val:
|
||||
continue
|
||||
|
||||
result = await db.execute(
|
||||
select(ExchangeRate).where(
|
||||
ExchangeRate.base_currency == base,
|
||||
ExchangeRate.quote_currency == quote,
|
||||
)
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
if existing:
|
||||
existing.rate = Decimal(str(round(rate_val, 8)))
|
||||
existing.fetched_at = now
|
||||
else:
|
||||
db.add(ExchangeRate(
|
||||
id=_uuid.uuid4(),
|
||||
base_currency=base,
|
||||
quote_currency=quote,
|
||||
rate=Decimal(str(round(rate_val, 8))),
|
||||
source="exchangerate-api",
|
||||
fetched_at=now,
|
||||
))
|
||||
await db.commit()
|
||||
logger.info("fx_sync_done", pairs=len(PAIRS))
|
||||
except Exception as exc:
|
||||
await db.rollback()
|
||||
logger.error("fx_sync_db_failed", error=str(exc))
|
||||
31
backend/app/workers/price_sync.py
Normal file
31
backend/app/workers/price_sync.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
import structlog
|
||||
from sqlalchemy import select
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
async def price_sync_job() -> None:
|
||||
from app.dependencies import get_session_factory
|
||||
from app.db.models.asset import Asset
|
||||
from app.services.price_feed_service import fetch_price
|
||||
from app.services.investment_service import update_asset_price
|
||||
|
||||
session_factory = get_session_factory()
|
||||
if not session_factory:
|
||||
return
|
||||
|
||||
async with session_factory() as db:
|
||||
try:
|
||||
result = await db.execute(select(Asset).where(Asset.is_active == True)) # noqa: E712
|
||||
assets = result.scalars().all()
|
||||
updated = 0
|
||||
for asset in assets:
|
||||
data = await fetch_price(asset.symbol, asset.data_source, asset.data_source_id)
|
||||
if data and data.get("price"):
|
||||
await update_asset_price(db, asset, data["price"], data.get("change_24h"))
|
||||
updated += 1
|
||||
await db.commit()
|
||||
logger.info("price_sync_done", updated=updated, total=len(assets))
|
||||
except Exception as exc:
|
||||
await db.rollback()
|
||||
logger.error("price_sync_failed", error=str(exc))
|
||||
33
backend/app/workers/scheduler.py
Normal file
33
backend/app/workers/scheduler.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
"""
|
||||
APScheduler background jobs. Starts with the FastAPI lifespan.
|
||||
"""
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
_scheduler: AsyncIOScheduler | None = None
|
||||
|
||||
|
||||
async def start_scheduler() -> None:
|
||||
global _scheduler
|
||||
from app.workers.snapshot import snapshot_job
|
||||
from app.workers.price_sync import price_sync_job
|
||||
from app.workers.fx_sync import fx_sync_job
|
||||
_scheduler = AsyncIOScheduler()
|
||||
|
||||
_scheduler.add_job(snapshot_job, CronTrigger(hour=2, minute=0), id="nw_snapshot")
|
||||
_scheduler.add_job(price_sync_job, CronTrigger(minute="*/15"), id="price_sync")
|
||||
_scheduler.add_job(fx_sync_job, CronTrigger(minute=0), id="fx_sync")
|
||||
# _scheduler.add_job(backup_job, CronTrigger(hour=3), id="backup")
|
||||
# _scheduler.add_job(ml_retrain_job, CronTrigger(day_of_week="sun", hour=1), id="ml_retrain")
|
||||
|
||||
_scheduler.start()
|
||||
logger.info("scheduler_started")
|
||||
|
||||
|
||||
async def stop_scheduler() -> None:
|
||||
if _scheduler and _scheduler.running:
|
||||
_scheduler.shutdown(wait=False)
|
||||
logger.info("scheduler_stopped")
|
||||
23
backend/app/workers/snapshot.py
Normal file
23
backend/app/workers/snapshot.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
import structlog
|
||||
from sqlalchemy import select
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
async def snapshot_job() -> None:
|
||||
from app.dependencies import get_session_factory
|
||||
from app.db.models.user import User
|
||||
from app.services.report_service import take_net_worth_snapshot
|
||||
|
||||
session_factory = get_session_factory()
|
||||
async with session_factory() as db:
|
||||
try:
|
||||
result = await db.execute(select(User).where(User.deleted_at.is_(None)))
|
||||
users = result.scalars().all()
|
||||
for user in users:
|
||||
await take_net_worth_snapshot(db, user.id, user.base_currency)
|
||||
await db.commit()
|
||||
logger.info("snapshot_job_done", users=len(users))
|
||||
except Exception as exc:
|
||||
await db.rollback()
|
||||
logger.error("snapshot_job_failed", error=str(exc))
|
||||
Loading…
Add table
Add a link
Reference in a new issue