MyMidas/backend/app/services/investment_service.py
megaproxy 312594f3d2 Include investment holding values in net worth and account balances
- Net worth report, balance sheet, and daily snapshots now add holding
  market values (falling back to cost basis) to investment-type account
  balances (investment, pension, stocks_shares_isa, crypto_wallet)
- Accounts list shows total value for investment accounts with a
  breakdown line ("£X cash + £Y holdings") when both are non-zero
- Add Holding modal gains a "Debit account for this purchase" toggle
  that creates a matching withdrawal transaction, enabling proper cash
  flow tracking for users who fund their brokerage via transfer first
- Both simple (holdings-only) and full cash-flow workflows produce
  correct net worth figures without double-counting

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-23 10:10:19 +00:00

513 lines
18 KiB
Python

import uuid
from datetime import date, datetime, timezone
from decimal import Decimal
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models.asset import Asset
from app.db.models.asset_price import AssetPrice
from app.db.models.investment_holding import InvestmentHolding
from app.db.models.investment_transaction import InvestmentTransaction
from app.schemas.investment import (
CapitalGainsDisposal,
CapitalGainsReport,
HoldingCreate,
HoldingResponse,
InvestmentTxnCreate,
PerformanceMetrics,
PortfolioSummary,
TaxYearSummary,
)
async def _get_asset(db: AsyncSession, asset_id: uuid.UUID) -> Asset | None:
    """Load a single Asset by primary key; return None when no row matches."""
    stmt = select(Asset).where(Asset.id == asset_id)
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()
async def _fetch_fx_rate(db: AsyncSession, from_currency: str, to_currency: str) -> Decimal:
    """Return the latest stored rate converting ``from_currency`` into ``to_currency``.

    Lookup order:
      1. identical currencies -> Decimal("1") without touching the database;
      2. the most recently fetched ExchangeRate quoted from_currency -> to_currency;
      3. otherwise the reciprocal of the most recent reverse quote
         (to_currency -> from_currency), when one exists with a non-zero rate;
      4. otherwise Decimal("1"), so callers keep working but amounts are
         effectively treated as already being in the target currency.
    """
    if from_currency == to_currency:
        return Decimal("1")
    # Kept as a function-local import as in the original (presumably to avoid
    # an import cycle between models and services — TODO confirm).
    from app.db.models.currency import ExchangeRate

    async def _latest(base: str, quote: str):
        # Newest ExchangeRate row for base -> quote, or None.
        result = await db.execute(
            select(ExchangeRate)
            .where(ExchangeRate.base_currency == base, ExchangeRate.quote_currency == quote)
            .order_by(ExchangeRate.fetched_at.desc())
            .limit(1)
        )
        return result.scalar_one_or_none()

    direct = await _latest(from_currency, to_currency)
    if direct is not None:
        return direct.rate

    # A missing direct quote used to fall straight back to parity even when the
    # inverse pair was stored; inverting it is far closer to the truth than 1.
    inverse = await _latest(to_currency, from_currency)
    if inverse is not None and inverse.rate:
        return Decimal("1") / inverse.rate

    return Decimal("1")
def _holding_to_response(
    holding: InvestmentHolding,
    asset: Asset,
    fx_rates: dict[tuple[str, str], Decimal] | None = None,
) -> HoldingResponse:
    """Build the API view of a holding, valued at the asset's last known price.

    The asset's ``last_price`` is quoted in ``asset.currency``; when that differs
    from the holding's currency it is converted using
    ``fx_rates[(asset.currency, holding.currency)]`` (defaulting to 1 if the
    pair is absent). When the asset has no price, current price/value and the
    unrealised gain figures are all None. The gain percentage is rounded to
    2 dp and is None when the cost basis is not positive.
    """
    rates = fx_rates or {}
    cost_basis_total = holding.quantity * holding.avg_cost_basis

    native_price = asset.last_price
    if native_price is None:
        current_price = None
    elif asset.currency == holding.currency:
        current_price = native_price
    else:
        conversion = rates.get((asset.currency, holding.currency), Decimal("1"))
        current_price = native_price * conversion

    if current_price is None:
        current_value = None
        unrealised_gain = None
        unrealised_gain_pct = None
    else:
        current_value = holding.quantity * current_price
        unrealised_gain = current_value - cost_basis_total
        unrealised_gain_pct = (
            (unrealised_gain / cost_basis_total * 100).quantize(Decimal("0.01"))
            if cost_basis_total > 0
            else None
        )

    return HoldingResponse(
        id=holding.id,
        account_id=holding.account_id,
        asset_id=holding.asset_id,
        symbol=asset.symbol,
        asset_name=asset.name,
        asset_type=asset.type,
        quantity=holding.quantity,
        avg_cost_basis=holding.avg_cost_basis,
        current_price=current_price,
        current_value=current_value,
        cost_basis_total=cost_basis_total,
        unrealised_gain=unrealised_gain,
        unrealised_gain_pct=unrealised_gain_pct,
        currency=holding.currency,
        price_change_24h=asset.price_change_24h,
    )
async def _load_holdings_with_prices(
    db: AsyncSession, user_id: uuid.UUID, base_currency: str
) -> tuple[list[InvestmentHolding], dict[uuid.UUID, Asset], dict[tuple[str, str], Decimal]]:
    """Load a user's open holdings together with what is needed to value them.

    Returns ``(holdings, assets, fx_rates)``:
      * holdings - the user's holdings with quantity > 0;
      * assets   - asset_id -> Asset for every asset row that exists (holdings
                   whose asset is missing are simply absent from this map);
      * fx_rates - every (from, to) pair needed to convert an asset's price into
                   its holding's currency and a holding's currency into
                   ``base_currency``.
    """
    result = await db.execute(
        select(InvestmentHolding).where(
            InvestmentHolding.user_id == user_id,
            InvestmentHolding.quantity > 0,
        )
    )
    holdings = list(result.scalars().all())
    if not holdings:
        return holdings, {}, {}

    # One query for all referenced assets instead of one round-trip per asset.
    asset_ids = list({h.asset_id for h in holdings})
    asset_result = await db.execute(select(Asset).where(Asset.id.in_(asset_ids)))
    assets: dict[uuid.UUID, Asset] = {a.id: a for a in asset_result.scalars().all()}

    pairs_needed: set[tuple[str, str]] = set()
    for h in holdings:
        asset = assets.get(h.asset_id)
        if asset is None:
            continue
        if asset.currency != h.currency:
            pairs_needed.add((asset.currency, h.currency))
        if h.currency != base_currency:
            pairs_needed.add((h.currency, base_currency))

    fx_rates: dict[tuple[str, str], Decimal] = {}
    for from_curr, to_curr in pairs_needed:
        fx_rates[(from_curr, to_curr)] = await _fetch_fx_rate(db, from_curr, to_curr)
    return holdings, assets, fx_rates


def _rate_to_base(
    fx_rates: dict[tuple[str, str], Decimal], currency: str, base_currency: str
) -> Decimal:
    """Rate converting ``currency`` into ``base_currency``; 1 when equal or unknown."""
    if currency == base_currency:
        return Decimal("1")
    return fx_rates.get((currency, base_currency), Decimal("1"))


async def get_portfolio(db: AsyncSession, user_id: uuid.UUID, base_currency: str = "GBP") -> PortfolioSummary:
    """Summarise the user's open holdings, with totals expressed in ``base_currency``.

    Each holding is valued at its asset's last price (converted to the holding
    currency). Holdings whose asset has no price yet are valued at cost basis,
    so they contribute zero gain instead of appearing as a 100% loss; this
    matches ``get_portfolio_value_by_account``. Per-holding responses still
    report ``current_value`` as None for such holdings. Holdings whose asset row
    is missing are omitted entirely.
    """
    holdings, assets, fx_rates = await _load_holdings_with_prices(db, user_id, base_currency)

    responses: list[HoldingResponse] = []
    total_value = Decimal("0")
    total_cost = Decimal("0")
    for h in holdings:
        asset = assets.get(h.asset_id)
        if asset is None:
            continue
        r = _holding_to_response(h, asset, fx_rates)
        responses.append(r)
        to_base = _rate_to_base(fx_rates, h.currency, base_currency)
        total_cost += r.cost_basis_total * to_base
        # Previously unpriced holdings added cost but no value, dragging the
        # gain down by their full cost; fall back to cost basis instead.
        value = r.current_value if r.current_value is not None else r.cost_basis_total
        total_value += value * to_base

    total_gain = total_value - total_cost
    total_gain_pct = (total_gain / total_cost * 100).quantize(Decimal("0.01")) if total_cost > 0 else Decimal("0")
    return PortfolioSummary(
        total_value=total_value,
        total_cost=total_cost,
        total_gain=total_gain,
        total_gain_pct=total_gain_pct,
        currency=base_currency,
        holdings=responses,
    )


async def get_portfolio_value_by_account(
    db: AsyncSession, user_id: uuid.UUID, base_currency: str
) -> dict[uuid.UUID, Decimal]:
    """Return total holding value (in base_currency) keyed by account_id.

    Holdings are valued at market price when the asset has one and at cost
    basis otherwise. Holdings whose asset row is missing are skipped, and
    accounts with no valued holdings are absent from the result.
    """
    holdings, assets, fx_rates = await _load_holdings_with_prices(db, user_id, base_currency)

    totals: dict[uuid.UUID, Decimal] = {}
    for h in holdings:
        asset = assets.get(h.asset_id)
        if asset is None:
            continue
        r = _holding_to_response(h, asset, fx_rates)
        value = r.current_value if r.current_value is not None else r.cost_basis_total
        to_base = _rate_to_base(fx_rates, h.currency, base_currency)
        totals[h.account_id] = totals.get(h.account_id, Decimal("0")) + value * to_base
    return totals
async def get_holding(db: AsyncSession, user_id: uuid.UUID, holding_id: uuid.UUID) -> InvestmentHolding | None:
    """Return holding ``holding_id`` only if it is owned by ``user_id``; otherwise None."""
    owned_by_user = InvestmentHolding.user_id == user_id
    matches_id = InvestmentHolding.id == holding_id
    rows = await db.execute(select(InvestmentHolding).where(matches_id, owned_by_user))
    return rows.scalar_one_or_none()
async def create_holding(db: AsyncSession, user_id: uuid.UUID, data: HoldingCreate) -> InvestmentHolding:
    """Create the user's holding for (data.account_id, data.asset_id), or return the existing one.

    If a holding for that account/asset pair already exists it is returned
    unchanged: the quantity, cost basis and currency in ``data`` are ignored.
    NOTE(review): confirm callers expect this (rather than merging quantities).
    Newly created holdings are flushed and refreshed before being returned.
    """
    lookup = await db.execute(
        select(InvestmentHolding).where(
            InvestmentHolding.user_id == user_id,
            InvestmentHolding.account_id == data.account_id,
            InvestmentHolding.asset_id == data.asset_id,
        )
    )
    already_there = lookup.scalar_one_or_none()
    if already_there is not None:
        return already_there

    timestamp = datetime.now(timezone.utc)
    new_holding = InvestmentHolding(
        id=uuid.uuid4(),
        user_id=user_id,
        account_id=data.account_id,
        asset_id=data.asset_id,
        quantity=data.quantity,
        avg_cost_basis=data.avg_cost_basis,
        currency=data.currency,
        created_at=timestamp,
        updated_at=timestamp,
    )
    db.add(new_holding)
    await db.flush()
    await db.refresh(new_holding)
    return new_holding
async def add_investment_transaction(
    db: AsyncSession, user_id: uuid.UUID, data: InvestmentTxnCreate
) -> InvestmentTransaction:
    """Record an investment transaction and apply its effect to the parent holding.

    Effects on the holding by ``data.type``:
      * "buy" / "transfer_in": quantity increases; avg_cost_basis becomes the
        quantity-weighted average of the old basis and ``data.price``
        (fees are not included in the basis).
      * "sell" / "transfer_out": quantity decreases, clamped at 0.
      * "split": ``data.quantity`` is the split multiplier (e.g. 2 for 2-for-1);
        quantity is multiplied and avg_cost_basis divided by it. As before, a
        split is only applied when ``data.price > 0``; it is now also skipped
        for a non-positive multiplier instead of dividing by zero or producing
        a negative position.
      * any other type (e.g. dividend, fee): holding left untouched.

    ``total_amount`` is stored as quantity * price + fees.

    Raises:
        ValueError: if the holding does not exist or is not owned by user_id.
    """
    holding = await get_holding(db, user_id, data.holding_id)
    if not holding:
        raise ValueError("Holding not found")

    now = datetime.now(timezone.utc)
    txn = InvestmentTransaction(
        id=uuid.uuid4(),
        user_id=user_id,
        holding_id=data.holding_id,
        type=data.type,
        quantity=data.quantity,
        price=data.price,
        fees=data.fees,
        total_amount=data.quantity * data.price + data.fees,
        currency=data.currency,
        date=data.date,
        created_at=now,
    )
    db.add(txn)

    if data.type in ("buy", "transfer_in"):
        new_qty = holding.quantity + data.quantity
        if new_qty > 0:
            holding.avg_cost_basis = (
                (holding.quantity * holding.avg_cost_basis + data.quantity * data.price)
                / new_qty
            )
        holding.quantity = new_qty
    elif data.type in ("sell", "transfer_out"):
        holding.quantity = max(Decimal("0"), holding.quantity - data.quantity)
    elif data.type == "split":
        # NOTE(review): the price > 0 gate is preserved from the original; it is
        # unclear why price should control whether a split applies.
        # The ratio guard prevents DivisionByZero / InvalidOperation on 0 and
        # a negative quantity/cost basis on a negative ratio.
        if data.price > 0 and data.quantity > 0:
            holding.quantity = holding.quantity * data.quantity
            holding.avg_cost_basis = holding.avg_cost_basis / data.quantity
    # dividend and fee don't affect quantity/cost basis

    holding.updated_at = now
    await db.flush()
    await db.refresh(txn)
    return txn
async def list_investment_transactions(
    db: AsyncSession, user_id: uuid.UUID, holding_id: uuid.UUID
) -> list[InvestmentTransaction]:
    """Return the user's transactions for one holding, newest ``date`` first."""
    stmt = (
        select(InvestmentTransaction)
        .where(InvestmentTransaction.user_id == user_id)
        .where(InvestmentTransaction.holding_id == holding_id)
        .order_by(InvestmentTransaction.date.desc())
    )
    rows = await db.execute(stmt)
    return list(rows.scalars())
async def get_performance(db: AsyncSession, user_id: uuid.UUID, base_currency: str = "GBP") -> PerformanceMetrics:
    """Report simple total-return metrics derived from the current portfolio summary.

    ``total_return`` / ``total_return_pct`` are the portfolio's total gain and
    gain percentage. TWRR is not computed yet and is always None (it needs
    snapshot history).
    """
    summary = await get_portfolio(db, user_id, base_currency)
    return PerformanceMetrics(
        twrr=None,  # placeholder until snapshot history is available
        total_return=summary.total_gain,
        total_return_pct=summary.total_gain_pct,
        currency=base_currency,
    )
def _uk_tax_year(d: date) -> str:
    """Return the UK tax year label containing ``d``, e.g. '2024/25'.

    UK tax years run from 6 April to 5 April, so dates on or after 6 April
    belong to the year starting that calendar year; earlier dates belong to
    the year that started in the previous calendar year.
    """
    start_year = d.year if (d.month, d.day) >= (4, 6) else d.year - 1
    return f"{start_year}/{str(start_year + 1)[2:]}"
async def get_capital_gains(
    db: AsyncSession, user_id: uuid.UUID, base_currency: str = "GBP"
) -> CapitalGainsReport:
    """
    Compute realised capital gains grouped by UK tax year (6 April - 5 April).

    Each *holding's* transactions are replayed chronologically (by date, then
    created_at) against a running pool, Section 104 style:

    * "buy" / "transfer_in" add their quantity and full cost
      (quantity * price + fees) to the pool.
    * "sell" / "transfer_out" are disposals. Allowable cost is
      (sold_qty / pool_qty) * pool_cost and proceeds are price * sold_qty - fees.
      Quantity beyond what the pool holds is ignored, and a disposal against
      an empty pool is skipped.
    * "split" multiplies the pooled quantity by the ratio in ``quantity``.
      Pooled cost is unchanged. It applies only when price > 0 and the ratio is
      positive.

    Amounts are converted to base_currency with the *current* FX rate (not
    the rate on the disposal date) and rounded to 2 dp per disposal. Tax
    years are returned newest first; disposals within a year oldest first.

    NOTE(review): HMRC pools Section 104 holdings per asset across all of a
    person's taxable accounts. Here each holding (account + asset) is pooled
    separately, transfer_out is treated as a disposal, and no account types
    (e.g. ISA/pension) are excluded. Confirm this is intended.
    """
    holdings_result = await db.execute(
        select(InvestmentHolding).where(InvestmentHolding.user_id == user_id)
    )
    holdings = holdings_result.scalars().all()

    # Pre-fetch assets and the FX rates for every holding currency.
    assets: dict[uuid.UUID, Asset] = {}
    holding_currencies: set[str] = set()
    for h in holdings:
        if h.asset_id not in assets:
            a = await _get_asset(db, h.asset_id)
            if a:
                assets[h.asset_id] = a
        holding_currencies.add(h.currency)

    fx_rates: dict[tuple[str, str], Decimal] = {}
    for curr in holding_currencies:
        if curr != base_currency:
            fx_rates[(curr, base_currency)] = await _fetch_fx_rate(db, curr, base_currency)

    disposals_by_year: dict[str, list[CapitalGainsDisposal]] = {}
    for h in holdings:
        asset = assets.get(h.asset_id)
        if not asset:
            continue
        # Depends only on the holding currency, so compute it once per holding.
        to_base = (
            fx_rates.get((h.currency, base_currency), Decimal("1"))
            if h.currency != base_currency
            else Decimal("1")
        )

        txns_result = await db.execute(
            select(InvestmentTransaction)
            .where(InvestmentTransaction.holding_id == h.id)
            .order_by(InvestmentTransaction.date.asc(), InvestmentTransaction.created_at.asc())
        )
        txns = txns_result.scalars().all()

        pool_qty = Decimal("0")
        pool_cost = Decimal("0")  # in holding.currency
        for txn in txns:
            if txn.type in ("buy", "transfer_in"):
                pool_qty += txn.quantity
                pool_cost += txn.quantity * txn.price + txn.fees
            elif txn.type in ("sell", "transfer_out") and pool_qty > 0:
                sell_qty = min(txn.quantity, pool_qty)
                cost_of_disposal = pool_cost / pool_qty * sell_qty
                proceeds = txn.price * sell_qty - txn.fees

                proceeds_base = (proceeds * to_base).quantize(Decimal("0.01"))
                cost_base = (cost_of_disposal * to_base).quantize(Decimal("0.01"))
                disposals_by_year.setdefault(_uk_tax_year(txn.date), []).append(
                    CapitalGainsDisposal(
                        date=txn.date,
                        symbol=asset.symbol,
                        asset_name=asset.name,
                        quantity=sell_qty,
                        proceeds=proceeds_base,
                        cost=cost_base,
                        gain=proceeds_base - cost_base,
                        currency=base_currency,
                    )
                )

                pool_qty -= sell_qty
                pool_cost -= cost_of_disposal
                if pool_qty <= 0:
                    pool_qty = Decimal("0")
                    pool_cost = Decimal("0")
            elif txn.type == "split" and txn.price > 0 and txn.quantity > 0:
                # A zero/negative ratio would otherwise wipe or invert the pooled
                # quantity while leaving its cost behind.
                pool_qty = pool_qty * txn.quantity
                # pool_cost stays the same; avg cost per unit changes

    tax_years: list[TaxYearSummary] = []
    for year_label in sorted(disposals_by_year, reverse=True):
        year_disposals = sorted(disposals_by_year[year_label], key=lambda d: d.date)
        total_proceeds = sum((d.proceeds for d in year_disposals), Decimal("0"))
        total_cost = sum((d.cost for d in year_disposals), Decimal("0"))
        tax_years.append(TaxYearSummary(
            tax_year=year_label,
            disposals=year_disposals,
            total_proceeds=total_proceeds,
            total_cost=total_cost,
            total_gain=total_proceeds - total_cost,
            currency=base_currency,
        ))
    return CapitalGainsReport(tax_years=tax_years, currency=base_currency)
async def get_or_create_asset(
    db: AsyncSession, symbol: str, name: str, asset_type: str,
    currency: str, data_source: str, data_source_id: str | None,
    exchange: str | None = None,
) -> Asset:
    """Return the Asset identified by (upper-cased symbol, data_source), creating it if absent.

    An existing asset is returned as-is; the other arguments are ignored in
    that case. A new asset is stored with its symbol upper-cased and marked
    active, then flushed and refreshed from the database before being returned.
    """
    normalized_symbol = symbol.upper()
    lookup = await db.execute(
        select(Asset).where(Asset.symbol == normalized_symbol, Asset.data_source == data_source)
    )
    found = lookup.scalar_one_or_none()
    if found is not None:
        return found

    timestamp = datetime.now(timezone.utc)
    created = Asset(
        id=uuid.uuid4(),
        symbol=normalized_symbol,
        name=name,
        type=asset_type,
        currency=currency,
        exchange=exchange,
        data_source=data_source,
        data_source_id=data_source_id,
        is_active=True,
        created_at=timestamp,
        updated_at=timestamp,
    )
    db.add(created)
    await db.flush()
    await db.refresh(created)
    return created
async def update_asset_price(
    db: AsyncSession, asset: Asset, price: Decimal, change_24h: Decimal | None
) -> None:
    """Store a new latest price (and optional 24h change) on ``asset`` and flush.

    ``last_price_at`` and ``updated_at`` are set to the same UTC timestamp.
    Previously they came from two separate ``now()`` calls and could disagree.
    """
    now = datetime.now(timezone.utc)
    asset.last_price = price
    asset.price_change_24h = change_24h
    asset.last_price_at = now
    asset.updated_at = now
    await db.flush()
async def upsert_price_history(db: AsyncSession, asset_id: uuid.UUID, rows: list[dict]) -> int:
    """Insert or update daily price rows for ``asset_id``, matched on (asset_id, date).

    Each row must provide "date" and "close". "open", "high", "low" and
    "volume" are optional and stored as None when absent. This now applies on
    the update path as well as the insert path; previously an update raised
    KeyError for a row missing any of them. All newly inserted rows share one
    ``created_at`` timestamp.

    Returns the number of rows processed (inserts plus updates).

    NOTE(review): duplicate dates within ``rows`` rely on session autoflush so
    the second SELECT sees the first insert — confirm the session autoflushes.
    """
    now = datetime.now(timezone.utc)
    count = 0
    for row in rows:
        result = await db.execute(
            select(AssetPrice).where(AssetPrice.asset_id == asset_id, AssetPrice.date == row["date"])
        )
        existing = result.scalar_one_or_none()
        if existing:
            existing.open = row.get("open")
            existing.high = row.get("high")
            existing.low = row.get("low")
            existing.close = row["close"]
            existing.volume = row.get("volume")
        else:
            db.add(AssetPrice(
                id=uuid.uuid4(),
                asset_id=asset_id,
                date=row["date"],
                open=row.get("open"),
                high=row.get("high"),
                low=row.get("low"),
                close=row["close"],
                volume=row.get("volume"),
                created_at=now,
            ))
        count += 1
    await db.flush()
    return count
async def get_price_history(
    db: AsyncSession, asset_id: uuid.UUID, days: int = 365
) -> list[AssetPrice]:
    """Return stored prices for ``asset_id`` dated within the last ``days`` days, oldest first.

    The window is inclusive: rows dated exactly ``date.today() - days`` are
    included.
    """
    from datetime import timedelta
    earliest = date.today() - timedelta(days=days)
    stmt = (
        select(AssetPrice)
        .where(AssetPrice.asset_id == asset_id)
        .where(AssetPrice.date >= earliest)
        .order_by(AssetPrice.date.asc())
    )
    rows = await db.execute(stmt)
    return list(rows.scalars())
async def search_assets(db: AsyncSession, query: str) -> list[Asset]:
    """Return up to 10 assets whose symbol or name contains ``query`` (case-insensitive).

    The query is stripped and upper-cased, then matched *literally*. LIKE
    wildcards typed by the user ("%", "_") are escaped via ``autoescape=True``.
    Previously "_" or "%" matched any character(s). An empty query is not
    rejected and matches any non-NULL symbol or name.
    """
    from sqlalchemy import or_, func
    q = query.strip().upper()
    result = await db.execute(
        select(Asset).where(
            or_(
                func.upper(Asset.symbol).contains(q, autoescape=True),
                func.upper(Asset.name).contains(q, autoescape=True),
            )
        ).limit(10)
    )
    return list(result.scalars().all())