import logging
import uuid
from collections import defaultdict
from datetime import date, datetime, timedelta, timezone
from decimal import Decimal

from sqlalchemy import func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.models.asset import Asset
from app.db.models.asset_price import AssetPrice
from app.db.models.investment_holding import InvestmentHolding
from app.db.models.investment_transaction import InvestmentTransaction
from app.schemas.investment import (
    CapitalGainsDisposal,
    CapitalGainsReport,
    HoldingCreate,
    HoldingResponse,
    InvestmentTxnCreate,
    PerformanceMetrics,
    PortfolioSummary,
    TaxYearSummary,
)

logger = logging.getLogger(__name__)

_ZERO = Decimal("0")
_ONE = Decimal("1")
_PENNY = Decimal("0.01")


async def _get_asset(db: AsyncSession, asset_id: uuid.UUID) -> Asset | None:
    """Return the asset with ``asset_id``, or ``None`` if it does not exist."""
    result = await db.execute(select(Asset).where(Asset.id == asset_id))
    return result.scalar_one_or_none()


async def _get_assets(db: AsyncSession, asset_ids: set[uuid.UUID]) -> dict[uuid.UUID, Asset]:
    """Load several assets in a single query, keyed by id.

    Ids with no matching asset are simply absent from the returned dict.
    """
    if not asset_ids:
        return {}
    result = await db.execute(select(Asset).where(Asset.id.in_(list(asset_ids))))
    return {asset.id: asset for asset in result.scalars().all()}


async def _fetch_fx_rate(db: AsyncSession, from_currency: str, to_currency: str) -> Decimal:
    """Return the most recent rate converting ``from_currency`` into ``to_currency``.

    Multiplying an amount in ``from_currency`` by the result gives the amount in
    ``to_currency``. The latest direct pair is preferred. If only the opposite
    pair is stored, its reciprocal is used. If neither exists, a warning is
    logged and ``Decimal("1")`` is returned, so amounts are left unconverted
    rather than the caller failing.
    """
    if from_currency == to_currency:
        return _ONE

    # Imported lazily; presumably to avoid an import cycle -- TODO confirm.
    from app.db.models.currency import ExchangeRate

    async def latest_rate(base: str, quote: str) -> Decimal | None:
        # Newest stored rate for base -> quote, or None when no row exists.
        result = await db.execute(
            select(ExchangeRate)
            .where(ExchangeRate.base_currency == base, ExchangeRate.quote_currency == quote)
            .order_by(ExchangeRate.fetched_at.desc())
            .limit(1)
        )
        row = result.scalar_one_or_none()
        return row.rate if row is not None else None

    direct = await latest_rate(from_currency, to_currency)
    if direct is not None:
        return direct

    inverse = await latest_rate(to_currency, from_currency)
    if inverse:  # also skips a zero rate, which cannot be inverted
        return _ONE / Decimal(inverse)

    logger.warning(
        "No exchange rate stored for %s->%s; falling back to 1", from_currency, to_currency
    )
    return _ONE


async def _fetch_fx_rates(
    db: AsyncSession, pairs: set[tuple[str, str]]
) -> dict[tuple[str, str], Decimal]:
    """Fetch a rate for every ``(from_currency, to_currency)`` pair, keyed by the pair."""
    rates: dict[tuple[str, str], Decimal] = {}
    for from_currency, to_currency in sorted(pairs):
        rates[(from_currency, to_currency)] = await _fetch_fx_rate(db, from_currency, to_currency)
    return rates


def _lookup_rate(
    fx_rates: dict[tuple[str, str], Decimal], from_currency: str, to_currency: str
) -> Decimal:
    """Return a pre-fetched rate: 1 for identical currencies, or when the pair is missing."""
    if from_currency == to_currency:
        return _ONE
    return fx_rates.get((from_currency, to_currency), _ONE)


def _holding_to_response(
    holding: InvestmentHolding,
    asset: Asset,
    fx_rates: dict[tuple[str, str], Decimal] | None = None,
) -> HoldingResponse:
    """Build the API response for one holding, including unrealised P&L.

    All monetary fields are expressed in ``holding.currency``. The asset's last
    price is converted from ``asset.currency`` using ``fx_rates`` (a missing
    pair falls back to 1). Value and gain fields are ``None`` when the asset
    has no price.
    """
    fx_rates = fx_rates or {}
    cost_basis_total = holding.quantity * holding.avg_cost_basis

    # Convert the asset's price into the holding's currency so P&L is comparable.
    current_price = asset.last_price
    if current_price is not None and asset.currency != holding.currency:
        current_price = current_price * _lookup_rate(fx_rates, asset.currency, holding.currency)

    current_value = holding.quantity * current_price if current_price is not None else None
    unrealised_gain = current_value - cost_basis_total if current_value is not None else None
    unrealised_gain_pct = None
    if unrealised_gain is not None and cost_basis_total > 0:
        unrealised_gain_pct = (unrealised_gain / cost_basis_total * 100).quantize(_PENNY)

    return HoldingResponse(
        id=holding.id,
        account_id=holding.account_id,
        asset_id=holding.asset_id,
        symbol=asset.symbol,
        asset_name=asset.name,
        asset_type=asset.type,
        quantity=holding.quantity,
        avg_cost_basis=holding.avg_cost_basis,
        current_price=current_price,
        current_value=current_value,
        cost_basis_total=cost_basis_total,
        unrealised_gain=unrealised_gain,
        unrealised_gain_pct=unrealised_gain_pct,
        currency=holding.currency,
        price_change_24h=asset.price_change_24h,
    )


async def get_portfolio(db: AsyncSession, user_id: uuid.UUID, base_currency: str = "GBP") -> PortfolioSummary:
    """Summarise the user's open holdings (quantity > 0), with totals in ``base_currency``.

    Each holding response stays in the holding's own currency. Only the
    portfolio totals are converted to ``base_currency``. Holdings whose asset
    row is missing are skipped.
    """
    result = await db.execute(
        select(InvestmentHolding).where(
            InvestmentHolding.user_id == user_id,
            InvestmentHolding.quantity > 0,
        )
    )
    holdings = result.scalars().all()

    assets = await _get_assets(db, {h.asset_id for h in holdings})

    # Two kinds of conversion are needed: asset price -> holding currency (per
    # holding), and holding currency -> base currency (for the totals).
    pairs_needed: set[tuple[str, str]] = set()
    for holding in holdings:
        asset = assets.get(holding.asset_id)
        if asset is None:
            continue
        if asset.currency != holding.currency:
            pairs_needed.add((asset.currency, holding.currency))
        if holding.currency != base_currency:
            pairs_needed.add((holding.currency, base_currency))
    fx_rates = await _fetch_fx_rates(db, pairs_needed)

    responses: list[HoldingResponse] = []
    total_value = _ZERO
    total_cost = _ZERO
    for holding in holdings:
        asset = assets.get(holding.asset_id)
        if asset is None:
            continue
        response = _holding_to_response(holding, asset, fx_rates)
        responses.append(response)

        to_base = _lookup_rate(fx_rates, holding.currency, base_currency)
        total_cost += response.cost_basis_total * to_base
        # NOTE(review): an unpriced holding adds to total_cost but not to
        # total_value, so it shows up as a loss in total_gain.
        if response.current_value is not None:
            total_value += response.current_value * to_base

    total_gain = total_value - total_cost
    total_gain_pct = (total_gain / total_cost * 100).quantize(_PENNY) if total_cost > 0 else _ZERO

    return PortfolioSummary(
        total_value=total_value,
        total_cost=total_cost,
        total_gain=total_gain,
        total_gain_pct=total_gain_pct,
        currency=base_currency,
        holdings=responses,
    )


async def get_holding(db: AsyncSession, user_id: uuid.UUID, holding_id: uuid.UUID) -> InvestmentHolding | None:
    """Return the holding ``holding_id`` if it belongs to ``user_id``, else ``None``."""
    result = await db.execute(
        select(InvestmentHolding).where(
            InvestmentHolding.id == holding_id,
            InvestmentHolding.user_id == user_id,
        )
    )
    return result.scalar_one_or_none()


async def create_holding(db: AsyncSession, user_id: uuid.UUID, data: HoldingCreate) -> InvestmentHolding:
    """Create a holding for (user, account, asset), or return the existing one.

    When a holding for the same user/account/asset already exists, it is
    returned unchanged. The quantity, cost basis and currency in ``data`` are
    then ignored.
    """
    result = await db.execute(
        select(InvestmentHolding).where(
            InvestmentHolding.user_id == user_id,
            InvestmentHolding.account_id == data.account_id,
            InvestmentHolding.asset_id == data.asset_id,
        )
    )
    existing = result.scalar_one_or_none()
    if existing:
        return existing

    now = datetime.now(timezone.utc)
    holding = InvestmentHolding(
        id=uuid.uuid4(),
        user_id=user_id,
        account_id=data.account_id,
        asset_id=data.asset_id,
        quantity=data.quantity,
        avg_cost_basis=data.avg_cost_basis,
        currency=data.currency,
        created_at=now,
        updated_at=now,
    )
    db.add(holding)
    await db.flush()
    await db.refresh(holding)
    return holding


async def add_investment_transaction(
    db: AsyncSession, user_id: uuid.UUID, data: InvestmentTxnCreate
) -> InvestmentTransaction:
    """Record a transaction and apply its effect to the holding's quantity and cost.

    - buy / transfer_in: add quantity and re-average the cost basis, excluding fees.
    - sell / transfer_out: reduce quantity, never below zero; the cost basis is unchanged.
    - split: ``data.quantity`` is the ratio; quantity is multiplied and cost basis divided by it.
    - Any other type (e.g. dividend, fee) leaves quantity and cost untouched.

    Raises:
        ValueError: if the holding does not exist or belongs to another user.
    """
    holding = await get_holding(db, user_id, data.holding_id)
    if not holding:
        raise ValueError("Holding not found")

    now = datetime.now(timezone.utc)
    # NOTE(review): total_amount is quantity * price + fees for every type,
    # including sells and splits -- confirm that is what consumers expect.
    txn = InvestmentTransaction(
        id=uuid.uuid4(),
        user_id=user_id,
        holding_id=data.holding_id,
        type=data.type,
        quantity=data.quantity,
        price=data.price,
        fees=data.fees,
        total_amount=data.quantity * data.price + data.fees,
        currency=data.currency,
        date=data.date,
        created_at=now,
    )
    db.add(txn)

    # NOTE(review): data.currency is not reconciled with holding.currency, and
    # fees are excluded here but included in the capital-gains pool cost.
    if data.type in ("buy", "transfer_in"):
        new_qty = holding.quantity + data.quantity
        if new_qty > 0:
            holding.avg_cost_basis = (
                holding.quantity * holding.avg_cost_basis + data.quantity * data.price
            ) / new_qty
        holding.quantity = new_qty
    elif data.type in ("sell", "transfer_out"):
        holding.quantity = max(_ZERO, holding.quantity - data.quantity)
    elif data.type == "split":
        # data.quantity is the split ratio. A non-positive ratio would divide by
        # zero or wipe the holding, so it is ignored (as is a non-positive price).
        if data.price > 0 and data.quantity > 0:
            holding.quantity = holding.quantity * data.quantity
            holding.avg_cost_basis = holding.avg_cost_basis / data.quantity

    holding.updated_at = now
    await db.flush()
    await db.refresh(txn)
    return txn


async def list_investment_transactions(
    db: AsyncSession, user_id: uuid.UUID, holding_id: uuid.UUID
) -> list[InvestmentTransaction]:
    """Return the user's transactions for one holding, newest first.

    Transactions on the same date are ordered by creation time, newest first,
    so the order is deterministic.
    """
    result = await db.execute(
        select(InvestmentTransaction)
        .where(
            InvestmentTransaction.user_id == user_id,
            InvestmentTransaction.holding_id == holding_id,
        )
        .order_by(InvestmentTransaction.date.desc(), InvestmentTransaction.created_at.desc())
    )
    return list(result.scalars().all())


async def get_performance(db: AsyncSession, user_id: uuid.UUID, base_currency: str = "GBP") -> PerformanceMetrics:
    """Return simple total-return metrics derived from the current portfolio summary."""
    portfolio = await get_portfolio(db, user_id, base_currency)
    return PerformanceMetrics(
        twrr=None,  # full TWRR requires snapshot history -- placeholder
        total_return=portfolio.total_gain,
        total_return_pct=portfolio.total_gain_pct,
        currency=base_currency,
    )


def _uk_tax_year(d: date) -> str:
    """Return the UK tax year label for ``d`` (e.g. '2024/25'); years start on 6 April."""
    start_year = d.year if d >= date(d.year, 4, 6) else d.year - 1
    return f"{start_year}/{str(start_year + 1)[2:]}"


def _replay_section_104(
    txns: list[InvestmentTransaction],
    asset: Asset,
    to_base: Decimal,
    base_currency: str,
) -> list[tuple[str, CapitalGainsDisposal]]:
    """Replay one holding's transactions through a Section 104 style pool.

    ``txns`` must be in chronological order. Returns ``(tax_year, disposal)``
    pairs for every sell/transfer_out matched against the pool. Amounts are
    converted with ``to_base`` and rounded to 0.01.
    """
    pool_qty = _ZERO
    pool_cost = _ZERO  # in the holding's currency
    disposals: list[tuple[str, CapitalGainsDisposal]] = []

    for txn in txns:
        if txn.type in ("buy", "transfer_in"):
            # The acquisition cost added to the pool includes fees.
            pool_qty += txn.quantity
            pool_cost += txn.quantity * txn.price + txn.fees
        elif txn.type in ("sell", "transfer_out") and pool_qty > 0:
            # Never dispose of more than the pool currently holds.
            sell_qty = min(txn.quantity, pool_qty)
            cost_per_unit = pool_cost / pool_qty
            cost_of_disposal = cost_per_unit * sell_qty
            proceeds = txn.price * sell_qty - txn.fees

            proceeds_base = (proceeds * to_base).quantize(_PENNY)
            cost_base = (cost_of_disposal * to_base).quantize(_PENNY)
            disposals.append(
                (
                    _uk_tax_year(txn.date),
                    CapitalGainsDisposal(
                        date=txn.date,
                        symbol=asset.symbol,
                        asset_name=asset.name,
                        quantity=sell_qty,
                        proceeds=proceeds_base,
                        cost=cost_base,
                        gain=proceeds_base - cost_base,
                        currency=base_currency,
                    ),
                )
            )

            pool_qty -= sell_qty
            pool_cost -= cost_of_disposal
            if pool_qty <= 0:
                pool_qty = _ZERO
                pool_cost = _ZERO
        elif txn.type == "split" and txn.price > 0 and txn.quantity > 0:
            # txn.quantity is the split ratio. Pool cost is unchanged, so the
            # cost per unit falls; a zero ratio would empty the pool but keep its cost.
            pool_qty = pool_qty * txn.quantity

    return disposals


def _summarise_tax_years(
    disposals_by_year: dict[str, list[CapitalGainsDisposal]], base_currency: str
) -> list[TaxYearSummary]:
    """Build one summary per tax year, newest year first, with disposals sorted by date."""
    summaries: list[TaxYearSummary] = []
    for year_label in sorted(disposals_by_year, reverse=True):
        year_disposals = sorted(disposals_by_year[year_label], key=lambda d: d.date)
        total_proceeds = sum((d.proceeds for d in year_disposals), _ZERO)
        total_cost = sum((d.cost for d in year_disposals), _ZERO)
        summaries.append(
            TaxYearSummary(
                tax_year=year_label,
                disposals=year_disposals,
                total_proceeds=total_proceeds,
                total_cost=total_cost,
                total_gain=total_proceeds - total_cost,
                currency=base_currency,
            )
        )
    return summaries


async def get_capital_gains(
    db: AsyncSession, user_id: uuid.UUID, base_currency: str = "GBP"
) -> CapitalGainsReport:
    """Compute capital gains using a Section 104 pool per holding.

    Each holding's transactions are replayed chronologically. On each
    sell/transfer_out, the cost of the disposal is
    ``(sold_qty / pool_qty) * pool_cost``. All values are converted to
    ``base_currency`` using the latest stored FX rates.

    NOTE(review) -- known simplifications versus HMRC share identification:
      * same-day and 30-day ("bed and breakfast") matching are not applied;
      * pools are per holding (account), not per asset across accounts;
      * transfer_out is treated as a disposal at its recorded price;
      * FX uses the current rate, not the rate on the disposal date.
    """
    holdings_result = await db.execute(
        select(InvestmentHolding).where(InvestmentHolding.user_id == user_id)
    )
    all_holdings = holdings_result.scalars().all()

    assets = await _get_assets(db, {h.asset_id for h in all_holdings})
    # Holdings whose asset row is missing cannot be reported on.
    holdings = [h for h in all_holdings if h.asset_id in assets]

    # Collect every holding's currency, not just the first holding seen per asset.
    fx_rates = await _fetch_fx_rates(
        db, {(h.currency, base_currency) for h in holdings if h.currency != base_currency}
    )

    # Load every relevant transaction in one query; the global (date, created_at)
    # ordering is preserved within each holding's list.
    txns_by_holding: dict[uuid.UUID, list[InvestmentTransaction]] = defaultdict(list)
    if holdings:
        txns_result = await db.execute(
            select(InvestmentTransaction)
            .where(InvestmentTransaction.holding_id.in_([h.id for h in holdings]))
            .order_by(InvestmentTransaction.date.asc(), InvestmentTransaction.created_at.asc())
        )
        for txn in txns_result.scalars().all():
            txns_by_holding[txn.holding_id].append(txn)

    disposals_by_year: dict[str, list[CapitalGainsDisposal]] = defaultdict(list)
    for holding in holdings:
        to_base = _lookup_rate(fx_rates, holding.currency, base_currency)
        for tax_year, disposal in _replay_section_104(
            txns_by_holding[holding.id], assets[holding.asset_id], to_base, base_currency
        ):
            disposals_by_year[tax_year].append(disposal)

    return CapitalGainsReport(
        tax_years=_summarise_tax_years(disposals_by_year, base_currency),
        currency=base_currency,
    )


async def get_or_create_asset(
    db: AsyncSession,
    symbol: str,
    name: str,
    asset_type: str,
    currency: str,
    data_source: str,
    data_source_id: str | None,
    exchange: str | None = None,
) -> Asset:
    """Return the asset for (upper-cased symbol, data_source), creating it if absent."""
    normalised_symbol = symbol.upper()
    result = await db.execute(
        select(Asset).where(Asset.symbol == normalised_symbol, Asset.data_source == data_source)
    )
    existing = result.scalar_one_or_none()
    if existing:
        return existing

    now = datetime.now(timezone.utc)
    asset = Asset(
        id=uuid.uuid4(),
        symbol=normalised_symbol,
        name=name,
        type=asset_type,
        currency=currency,
        exchange=exchange,
        data_source=data_source,
        data_source_id=data_source_id,
        is_active=True,
        created_at=now,
        updated_at=now,
    )
    db.add(asset)
    await db.flush()
    await db.refresh(asset)
    return asset


async def update_asset_price(
    db: AsyncSession, asset: Asset, price: Decimal, change_24h: Decimal | None
) -> None:
    """Store the latest price and 24h change on ``asset`` and flush.

    A single timestamp is used for both ``last_price_at`` and ``updated_at``.
    """
    now = datetime.now(timezone.utc)
    asset.last_price = price
    asset.price_change_24h = change_24h
    asset.last_price_at = now
    asset.updated_at = now
    await db.flush()


async def upsert_price_history(db: AsyncSession, asset_id: uuid.UUID, rows: list[dict]) -> int:
    """Insert or update daily price bars for an asset; return the number of rows processed.

    Each row needs ``date`` and ``close``. ``open``, ``high``, ``low`` and
    ``volume`` are optional. On insert, missing keys are stored as ``None``.
    On update, missing keys leave the existing value untouched instead of
    raising ``KeyError``.
    """
    now = datetime.now(timezone.utc)
    count = 0
    for row in rows:
        result = await db.execute(
            select(AssetPrice).where(AssetPrice.asset_id == asset_id, AssetPrice.date == row["date"])
        )
        existing = result.scalar_one_or_none()
        if existing:
            existing.close = row["close"]
            for field in ("open", "high", "low", "volume"):
                if field in row:
                    setattr(existing, field, row[field])
        else:
            db.add(
                AssetPrice(
                    id=uuid.uuid4(),
                    asset_id=asset_id,
                    date=row["date"],
                    open=row.get("open"),
                    high=row.get("high"),
                    low=row.get("low"),
                    close=row["close"],
                    volume=row.get("volume"),
                    created_at=now,
                )
            )
        count += 1
    await db.flush()
    return count


async def get_price_history(
    db: AsyncSession, asset_id: uuid.UUID, days: int = 365
) -> list[AssetPrice]:
    """Return the asset's price bars from the last ``days`` days, oldest first."""
    cutoff = date.today() - timedelta(days=days)
    result = await db.execute(
        select(AssetPrice)
        .where(AssetPrice.asset_id == asset_id, AssetPrice.date >= cutoff)
        .order_by(AssetPrice.date.asc())
    )
    return list(result.scalars().all())


async def search_assets(db: AsyncSession, query: str) -> list[Asset]:
    """Return up to 10 assets whose symbol or name contains ``query``, ignoring case.

    ``%`` and ``_`` in the query are matched literally (LIKE autoescape), not as
    wildcards. Results are ordered by symbol so the 10-row limit is deterministic.
    """
    q = query.strip().upper()
    result = await db.execute(
        select(Asset)
        .where(
            or_(
                func.upper(Asset.symbol).contains(q, autoescape=True),
                func.upper(Asset.name).contains(q, autoescape=True),
            )
        )
        .order_by(Asset.symbol)
        .limit(10)
    )
    return list(result.scalars().all())