import uuid
from datetime import date, datetime, timedelta, timezone
from decimal import Decimal

from sqlalchemy import func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.models.asset import Asset
from app.db.models.asset_price import AssetPrice
from app.db.models.investment_holding import InvestmentHolding
from app.db.models.investment_transaction import InvestmentTransaction
from app.schemas.investment import (
    HoldingCreate,
    HoldingResponse,
    InvestmentTxnCreate,
    PerformanceMetrics,
    PortfolioSummary,
)


async def _get_asset(db: AsyncSession, asset_id: uuid.UUID) -> Asset | None:
    """Return the Asset with primary key ``asset_id``, or None if absent."""
    result = await db.execute(select(Asset).where(Asset.id == asset_id))
    return result.scalar_one_or_none()


def _holding_to_response(holding: InvestmentHolding, asset: Asset) -> HoldingResponse:
    """Combine a holding with its asset's latest price into a HoldingResponse.

    Valuation fields (``current_value``, ``unrealised_gain``,
    ``unrealised_gain_pct``) are None when the asset has no ``last_price``.
    ``unrealised_gain_pct`` is also None when the cost basis is not positive.
    """
    cost_basis_total = holding.quantity * holding.avg_cost_basis
    current_price = asset.last_price
    # Compare against None explicitly: a price of 0 is a real (if grim) price
    # and must not be confused with "no price available".
    current_value = holding.quantity * current_price if current_price is not None else None
    unrealised_gain = (current_value - cost_basis_total) if current_value is not None else None
    unrealised_gain_pct = None
    if unrealised_gain is not None and cost_basis_total > 0:
        unrealised_gain_pct = (unrealised_gain / cost_basis_total * 100).quantize(Decimal("0.01"))

    return HoldingResponse(
        id=holding.id,
        account_id=holding.account_id,
        asset_id=holding.asset_id,
        symbol=asset.symbol,
        asset_name=asset.name,
        asset_type=asset.type,
        quantity=holding.quantity,
        avg_cost_basis=holding.avg_cost_basis,
        current_price=current_price,
        current_value=current_value,
        cost_basis_total=cost_basis_total,
        unrealised_gain=unrealised_gain,
        unrealised_gain_pct=unrealised_gain_pct,
        currency=holding.currency,
        price_change_24h=asset.price_change_24h,
    )


async def get_portfolio(db: AsyncSession, user_id: uuid.UUID) -> PortfolioSummary:
    """Summarise the user's open holdings (``quantity > 0``).

    Holdings whose Asset row cannot be found are skipped. ``total_cost`` is the
    cost basis of every returned holding. ``total_value`` only includes holdings
    that have a current price. ``total_gain`` / ``total_gain_pct`` are measured
    against the cost of those priced holdings, so an unpriced holding is not
    reported as a 100% loss.

    NOTE(review): values are summed across holdings regardless of their own
    ``currency`` and labelled "GBP"; no FX conversion happens here.
    """
    result = await db.execute(
        select(InvestmentHolding).where(
            InvestmentHolding.user_id == user_id,
            InvestmentHolding.quantity > 0,
        )
    )
    holdings = result.scalars().all()

    # Fetch all referenced assets in one query instead of one per holding.
    assets_by_id: dict[uuid.UUID, Asset] = {}
    asset_ids = list({h.asset_id for h in holdings})
    if asset_ids:
        asset_result = await db.execute(select(Asset).where(Asset.id.in_(asset_ids)))
        assets_by_id = {a.id: a for a in asset_result.scalars().all()}

    responses: list[HoldingResponse] = []
    total_value = Decimal("0")
    total_cost = Decimal("0")
    priced_cost = Decimal("0")  # cost basis of holdings that have a current value
    for h in holdings:
        asset = assets_by_id.get(h.asset_id)
        if asset is None:
            continue
        r = _holding_to_response(h, asset)
        responses.append(r)
        total_cost += r.cost_basis_total
        if r.current_value is not None:
            total_value += r.current_value
            priced_cost += r.cost_basis_total

    total_gain = total_value - priced_cost
    total_gain_pct = (
        (total_gain / priced_cost * 100).quantize(Decimal("0.01"))
        if priced_cost > 0
        else Decimal("0")
    )

    return PortfolioSummary(
        total_value=total_value,
        total_cost=total_cost,
        total_gain=total_gain,
        total_gain_pct=total_gain_pct,
        currency="GBP",
        holdings=responses,
    )


async def get_holding(
    db: AsyncSession, user_id: uuid.UUID, holding_id: uuid.UUID
) -> InvestmentHolding | None:
    """Return the holding ``holding_id`` if it belongs to ``user_id``, else None."""
    result = await db.execute(
        select(InvestmentHolding).where(
            InvestmentHolding.id == holding_id,
            InvestmentHolding.user_id == user_id,
        )
    )
    return result.scalar_one_or_none()


async def create_holding(
    db: AsyncSession, user_id: uuid.UUID, data: HoldingCreate
) -> InvestmentHolding:
    """Create a holding for (user, account, asset), or return the existing one.

    If a holding already exists for the same user/account/asset it is returned
    unchanged; ``data.quantity`` etc. are NOT merged into it.
    """
    result = await db.execute(
        select(InvestmentHolding).where(
            InvestmentHolding.user_id == user_id,
            InvestmentHolding.account_id == data.account_id,
            InvestmentHolding.asset_id == data.asset_id,
        )
    )
    existing = result.scalar_one_or_none()
    if existing:
        return existing

    now = datetime.now(timezone.utc)
    holding = InvestmentHolding(
        id=uuid.uuid4(),
        user_id=user_id,
        account_id=data.account_id,
        asset_id=data.asset_id,
        quantity=data.quantity,
        avg_cost_basis=data.avg_cost_basis,
        currency=data.currency,
        created_at=now,
        updated_at=now,
    )
    db.add(holding)
    await db.flush()
    await db.refresh(holding)
    return holding


def _apply_txn_to_holding(
    holding: InvestmentHolding, data: InvestmentTxnCreate, now: datetime
) -> None:
    """Update ``holding``'s quantity and average cost basis for transaction ``data``."""
    if data.type in ("buy", "transfer_in"):
        new_qty = holding.quantity + data.quantity
        if new_qty > 0:
            # Quantity-weighted average of the existing position and the new lot
            # (fees are not folded into the cost basis).
            holding.avg_cost_basis = (
                holding.quantity * holding.avg_cost_basis + data.quantity * data.price
            ) / new_qty
        holding.quantity = new_qty
    elif data.type in ("sell", "transfer_out"):
        # Clamp at zero: disposing of more than is held leaves an empty holding.
        holding.quantity = max(Decimal("0"), holding.quantity - data.quantity)
    elif data.type == "split":
        # For a split, ``data.quantity`` is the ratio applied to the position.
        # Guard on that ratio (it is the divisor below), not on the price.
        if data.quantity > 0:
            holding.quantity = holding.quantity * data.quantity
            holding.avg_cost_basis = holding.avg_cost_basis / data.quantity
    # Any other type (e.g. "dividend", "fee") leaves quantity/cost basis untouched.
    holding.updated_at = now


async def add_investment_transaction(
    db: AsyncSession, user_id: uuid.UUID, data: InvestmentTxnCreate
) -> InvestmentTransaction:
    """Record a transaction against one of the user's holdings and update it.

    ``total_amount`` is stored as ``quantity * price + fees``.

    Raises:
        ValueError: if the holding does not exist or belongs to another user.
    """
    holding = await get_holding(db, user_id, data.holding_id)
    if not holding:
        raise ValueError("Holding not found")

    now = datetime.now(timezone.utc)
    total = data.quantity * data.price + data.fees
    txn = InvestmentTransaction(
        id=uuid.uuid4(),
        user_id=user_id,
        holding_id=data.holding_id,
        type=data.type,
        quantity=data.quantity,
        price=data.price,
        fees=data.fees,
        total_amount=total,
        currency=data.currency,
        date=data.date,
        created_at=now,
    )
    db.add(txn)

    _apply_txn_to_holding(holding, data, now)

    await db.flush()
    await db.refresh(txn)
    return txn


async def list_investment_transactions(
    db: AsyncSession, user_id: uuid.UUID, holding_id: uuid.UUID
) -> list[InvestmentTransaction]:
    """Return the user's transactions for ``holding_id``, newest date first."""
    result = await db.execute(
        select(InvestmentTransaction)
        .where(
            InvestmentTransaction.user_id == user_id,
            InvestmentTransaction.holding_id == holding_id,
        )
        .order_by(InvestmentTransaction.date.desc())
    )
    return list(result.scalars().all())


async def get_performance(db: AsyncSession, user_id: uuid.UUID) -> PerformanceMetrics:
    """Return simple performance metrics derived from the current portfolio."""
    portfolio = await get_portfolio(db, user_id)
    return PerformanceMetrics(
        twrr=None,  # full TWRR requires snapshot history — placeholder
        total_return=portfolio.total_gain,
        total_return_pct=portfolio.total_gain_pct,
        currency="GBP",
    )


async def get_or_create_asset(
    db: AsyncSession,
    symbol: str,
    name: str,
    asset_type: str,
    currency: str,
    data_source: str,
    data_source_id: str | None,
    exchange: str | None = None,
) -> Asset:
    """Return the asset for (upper-cased ``symbol``, ``data_source``), creating it if missing."""
    normalised_symbol = symbol.upper()
    result = await db.execute(
        select(Asset).where(Asset.symbol == normalised_symbol, Asset.data_source == data_source)
    )
    existing = result.scalar_one_or_none()
    if existing:
        return existing

    now = datetime.now(timezone.utc)
    asset = Asset(
        id=uuid.uuid4(),
        symbol=normalised_symbol,
        name=name,
        type=asset_type,
        currency=currency,
        exchange=exchange,
        data_source=data_source,
        data_source_id=data_source_id,
        is_active=True,
        created_at=now,
        updated_at=now,
    )
    db.add(asset)
    await db.flush()
    await db.refresh(asset)
    return asset


async def update_asset_price(
    db: AsyncSession, asset: Asset, price: Decimal, change_24h: Decimal | None
) -> None:
    """Store the latest price (and 24h change) on ``asset`` and flush."""
    now = datetime.now(timezone.utc)
    asset.last_price = price
    asset.price_change_24h = change_24h
    asset.last_price_at = now
    asset.updated_at = now
    await db.flush()


async def upsert_price_history(db: AsyncSession, asset_id: uuid.UUID, rows: list[dict]) -> int:
    """Insert or update daily price rows for ``asset_id``; return the rows processed.

    Each row must provide ``"date"`` and ``"close"``. ``"open"``, ``"high"``,
    ``"low"`` and ``"volume"`` are optional: a new row stores None for a missing
    key, while an existing row keeps its stored value for that field.
    """
    count = 0
    now = datetime.now(timezone.utc)
    for row in rows:
        result = await db.execute(
            select(AssetPrice).where(AssetPrice.asset_id == asset_id, AssetPrice.date == row["date"])
        )
        existing = result.scalar_one_or_none()
        if existing:
            # Use .get like the insert path so a partial row does not KeyError.
            existing.open = row.get("open", existing.open)
            existing.high = row.get("high", existing.high)
            existing.low = row.get("low", existing.low)
            existing.close = row["close"]
            existing.volume = row.get("volume", existing.volume)
        else:
            db.add(
                AssetPrice(
                    id=uuid.uuid4(),
                    asset_id=asset_id,
                    date=row["date"],
                    open=row.get("open"),
                    high=row.get("high"),
                    low=row.get("low"),
                    close=row["close"],
                    volume=row.get("volume"),
                    created_at=now,
                )
            )
        count += 1
    await db.flush()
    return count


async def get_price_history(
    db: AsyncSession, asset_id: uuid.UUID, days: int = 365
) -> list[AssetPrice]:
    """Return price rows for ``asset_id`` from the last ``days`` days, oldest first."""
    cutoff = date.today() - timedelta(days=days)
    result = await db.execute(
        select(AssetPrice)
        .where(AssetPrice.asset_id == asset_id, AssetPrice.date >= cutoff)
        .order_by(AssetPrice.date.asc())
    )
    return list(result.scalars().all())


async def search_assets(db: AsyncSession, query: str) -> list[Asset]:
    """Return up to 10 assets whose symbol or name contains ``query`` (case-insensitive).

    ``%`` and ``_`` in ``query`` are matched literally rather than acting as
    LIKE wildcards.
    """
    q = query.strip().upper()
    result = await db.execute(
        select(Asset)
        .where(
            or_(
                func.upper(Asset.symbol).contains(q, autoescape=True),
                func.upper(Asset.name).contains(q, autoescape=True),
            )
        )
        # Order so that LIMIT returns a stable set of results.
        .order_by(Asset.symbol)
        .limit(10)
    )
    return list(result.scalars().all())