"""Account service helpers: CRUD, balance recalculation and net-worth totals."""

from __future__ import annotations

import logging
import uuid
from datetime import datetime, timezone
from decimal import Decimal

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.security import decrypt_field, encrypt_field
from app.db.models.account import Account
from app.db.models.currency import ExchangeRate
from app.db.models.transaction import Transaction
from app.schemas.account import AccountCreate, AccountUpdate

logger = logging.getLogger(__name__)

# Account types treated as liabilities: their balance is subtracted from
# net worth instead of being added to it.
LIABILITY_TYPES = {"credit_card", "loan", "mortgage"}


class AccountError(Exception):
    """Error raised by account operations.

    Attributes:
        detail: Human-readable description of the failure.
        status_code: Numeric status attached to the error (default 400;
            this module uses 404 for "not found"). Presumably mapped to an
            HTTP status by the caller -- confirm against the API layer.
    """

    def __init__(self, detail: str, status_code: int = 400):
        # Forward the message to Exception so str(exc), exc.args and
        # tracebacks carry it instead of being empty.
        super().__init__(detail)
        self.detail = detail
        self.status_code = status_code


def _encrypt(value: str | None) -> bytes | None:
    """Encrypt ``value`` with ``encrypt_field``; ``None`` passes through."""
    if value is None:
        return None
    return encrypt_field(value)


def _decrypt(data: bytes | None) -> str | None:
    """Decrypt ``data`` with ``decrypt_field``.

    Both ``None`` and empty bytes yield ``None``.
    """
    if not data:
        return None
    return decrypt_field(data)


def _to_response(account: Account) -> dict:
    """Build the plain-dict representation of ``account``.

    The encrypted columns (name, institution, notes) are decrypted; a
    missing name is reported as an empty string.
    """
    return {
        "id": account.id,
        "name": _decrypt(account.name_enc) or "",
        "institution": _decrypt(account.institution_enc),
        "type": account.type,
        "currency": account.currency,
        "current_balance": account.current_balance,
        "credit_limit": account.credit_limit,
        "interest_rate": account.interest_rate,
        "is_active": account.is_active,
        "include_in_net_worth": account.include_in_net_worth,
        "color": account.color,
        "icon": account.icon,
        "notes": _decrypt(account.notes_enc),
        "created_at": account.created_at,
        "updated_at": account.updated_at,
    }


async def create_account(
    db: AsyncSession,
    user_id: uuid.UUID,
    data: AccountCreate,
) -> dict:
    """Create an account for ``user_id`` and return its response dict.

    The opening balance from ``data`` becomes the initial current balance.
    The session is flushed but not committed here.
    """
    now = datetime.now(timezone.utc)
    account = Account(
        user_id=user_id,
        name_enc=encrypt_field(data.name),
        institution_enc=_encrypt(data.institution),
        type=data.type,
        currency=data.currency,
        current_balance=data.opening_balance,
        credit_limit=data.credit_limit,
        interest_rate=data.interest_rate,
        include_in_net_worth=data.include_in_net_worth,
        color=data.color,
        icon=data.icon,
        notes_enc=_encrypt(data.notes),
        created_at=now,
        updated_at=now,
    )
    db.add(account)
    await db.flush()
    return _to_response(account)


async def list_accounts(db: AsyncSession, user_id: uuid.UUID) -> list[dict]:
    """Return the user's non-deleted accounts, ordered by creation time."""
    result = await db.execute(
        select(Account)
        .where(
            Account.user_id == user_id,
            Account.deleted_at.is_(None),
        )
        .order_by(Account.created_at)
    )
    return [_to_response(a) for a in result.scalars()]


async def get_account(
    db: AsyncSession,
    account_id: uuid.UUID,
    user_id: uuid.UUID,
) -> Account:
    """Fetch one non-deleted account owned by ``user_id``.

    Raises:
        AccountError: With ``status_code=404`` when no matching account
            exists (wrong id, wrong owner, or soft-deleted).
    """
    result = await db.execute(
        select(Account).where(
            Account.id == account_id,
            Account.user_id == user_id,
            Account.deleted_at.is_(None),
        )
    )
    account = result.scalar_one_or_none()
    if not account:
        raise AccountError("Account not found", status_code=404)
    return account


async def update_account(
    db: AsyncSession,
    account_id: uuid.UUID,
    user_id: uuid.UUID,
    data: AccountUpdate,
) -> dict:
    """Apply the non-``None`` fields of ``data`` to an account.

    A field left as ``None`` is treated as "not provided" and keeps its
    current value, so this function cannot reset a field to ``None``.
    ``data.opening_balance``, when given, overwrites the current balance.

    Returns:
        The response dict of the updated account.

    Raises:
        AccountError: With ``status_code=404`` when the account is not found.
    """
    account = await get_account(db, account_id, user_id)
    now = datetime.now(timezone.utc)

    if data.name is not None:
        account.name_enc = encrypt_field(data.name)
    if data.institution is not None:
        account.institution_enc = _encrypt(data.institution)
    if data.opening_balance is not None:
        account.current_balance = data.opening_balance
    if data.credit_limit is not None:
        account.credit_limit = data.credit_limit
    if data.interest_rate is not None:
        account.interest_rate = data.interest_rate
    if data.include_in_net_worth is not None:
        account.include_in_net_worth = data.include_in_net_worth
    if data.is_active is not None:
        account.is_active = data.is_active
    if data.color is not None:
        account.color = data.color
    if data.icon is not None:
        account.icon = data.icon
    if data.notes is not None:
        account.notes_enc = _encrypt(data.notes)

    account.updated_at = now
    await db.flush()
    return _to_response(account)


async def delete_account(
    db: AsyncSession,
    account_id: uuid.UUID,
    user_id: uuid.UUID,
) -> None:
    """Soft-delete an account by stamping ``deleted_at``.

    Raises:
        AccountError: With ``status_code=404`` when the account is not found.
    """
    account = await get_account(db, account_id, user_id)
    # Use a single timestamp so deleted_at and updated_at agree exactly.
    now = datetime.now(timezone.utc)
    account.deleted_at = now
    account.updated_at = now
    await db.flush()


async def recalculate_balance(db: AsyncSession, account_id: uuid.UUID) -> None:
    """Recompute current_balance from all non-deleted transactions.

    An account without transactions gets a balance of zero. Nothing
    happens when the account id does not exist.
    """
    result = await db.execute(
        select(func.sum(Transaction.amount)).where(
            Transaction.account_id == account_id,
            Transaction.deleted_at.is_(None),
        )
    )
    # SUM over zero rows yields NULL; treat that as a zero balance.
    total = result.scalar_one_or_none() or Decimal("0")
    account = await db.get(Account, account_id)
    if account:
        account.current_balance = total
        account.updated_at = datetime.now(timezone.utc)
        await db.flush()


async def _fx_rate(db: AsyncSession, from_currency: str, to_currency: str) -> Decimal:
    """Return the most recently fetched rate from one currency to another.

    Identical currencies yield 1 without a query. When no rate row exists
    for the pair, the function falls back to 1 and logs a warning.
    """
    if from_currency == to_currency:
        return Decimal("1")
    result = await db.execute(
        select(ExchangeRate)
        .where(
            ExchangeRate.base_currency == from_currency,
            ExchangeRate.quote_currency == to_currency,
        )
        .order_by(ExchangeRate.fetched_at.desc())
        .limit(1)
    )
    er = result.scalar_one_or_none()
    if er is None:
        # NOTE(review): a 1:1 fallback can misstate converted amounts; the
        # behavior is kept for compatibility but is no longer silent.
        logger.warning(
            "No exchange rate found for %s -> %s; falling back to 1",
            from_currency,
            to_currency,
        )
        return Decimal("1")
    return er.rate


async def get_net_worth(db: AsyncSession, user_id: uuid.UUID, base_currency: str) -> dict:
    """Aggregate the user's assets and liabilities in ``base_currency``.

    Only non-deleted accounts flagged ``include_in_net_worth`` are counted.
    Accounts whose type is in ``LIABILITY_TYPES`` add the absolute value of
    their balance to liabilities; every other account adds its balance to
    assets. Accounts without a currency are assumed to be in
    ``base_currency``.

    Returns:
        A dict with ``total_assets``, ``total_liabilities``, ``net_worth``
        and ``base_currency``.
    """
    accounts = await db.execute(
        select(Account).where(
            Account.user_id == user_id,
            # SQL expression, not a Python truth test: `== True` is intended.
            Account.include_in_net_worth == True,  # noqa: E712
            Account.deleted_at.is_(None),
        )
    )

    total_assets = Decimal("0")
    total_liabilities = Decimal("0")
    # Cache rates per currency so each one is looked up at most once per call.
    rates: dict[str, Decimal] = {}

    for account in accounts.scalars():
        bal = account.current_balance or Decimal("0")
        acct_currency = account.currency or base_currency
        if acct_currency != base_currency:
            if acct_currency not in rates:
                rates[acct_currency] = await _fx_rate(db, acct_currency, base_currency)
            bal = bal * rates[acct_currency]

        if account.type in LIABILITY_TYPES:
            total_liabilities += abs(bal)
        else:
            total_assets += bal

    return {
        "total_assets": total_assets,
        "total_liabilities": total_liabilities,
        "net_worth": total_assets - total_liabilities,
        "base_currency": base_currency,
    }