"""Transaction service: create, list, fetch, update, soft-delete and CSV-import transactions.

Free-text fields (description, merchant, notes) are stored encrypted via
``encrypt_field`` and decrypted on read via ``decrypt_field``. Functions here
``flush`` the session but never ``commit``; committing is left to the caller.
"""

from __future__ import annotations

import hashlib
import uuid
from datetime import datetime, timezone
from decimal import Decimal, InvalidOperation

from sqlalchemy import and_, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.security import decrypt_field, encrypt_field
from app.db.models.transaction import Transaction
from app.schemas.transaction import TransactionCreate, TransactionFilter, TransactionUpdate
from app.services.account_service import recalculate_balance


class TransactionError(Exception):
    """Domain error raised by the transaction service.

    Attributes:
        detail: Human-readable error message.
        status_code: Status code to report for this error (defaults to 400).
    """

    def __init__(self, detail: str, status_code: int = 400):
        # Forward the message so str(exc), repr(exc) and logging show it.
        super().__init__(detail)
        self.detail = detail
        self.status_code = status_code


def _enc(v: str | None) -> bytes | None:
    """Encrypt *v*; ``None`` and ``""`` are stored as ``None``."""
    return encrypt_field(v) if v else None


def _dec(v: bytes | None) -> str | None:
    """Decrypt *v*; ``None`` and empty bytes yield ``None``."""
    return decrypt_field(v) if v else None


def _fx_rate(currency: str, base_currency: str) -> Decimal | None:
    """Return the exchange rate to store for a transaction in *currency*.

    Same-currency amounts get a rate of exactly 1. FX conversion is not
    implemented yet ("Phase 3"), so foreign-currency amounts get ``None``
    rather than a fabricated rate.
    """
    return Decimal("1") if currency == base_currency else None


def _to_response(t: Transaction) -> dict:
    """Serialize a ``Transaction`` row to a plain dict, decrypting the text fields.

    ``description`` falls back to ``""``; ``tags`` and ``attachment_refs``
    fall back to empty lists.
    """
    return {
        "id": t.id,
        "account_id": t.account_id,
        "transfer_account_id": t.transfer_account_id,
        "category_id": t.category_id,
        "type": t.type,
        "status": t.status,
        "amount": t.amount,
        "amount_base": t.amount_base,
        "currency": t.currency,
        "base_currency": t.base_currency,
        "exchange_rate": t.exchange_rate,
        "date": t.date,
        "description": _dec(t.description_enc) or "",
        "merchant": _dec(t.merchant_enc),
        "notes": _dec(t.notes_enc),
        "tags": t.tags or [],
        "is_recurring": t.is_recurring,
        "attachment_refs": t.attachment_refs or [],
        "created_at": t.created_at,
        "updated_at": t.updated_at,
    }


async def create_transaction(
    db: AsyncSession,
    user_id: uuid.UUID,
    data: TransactionCreate,
    base_currency: str,
) -> dict:
    """Create a transaction, plus a mirrored counter-entry for transfers.

    When ``data.type == "transfer"`` and ``data.transfer_account_id`` is set,
    a second row is written on the destination account with the amount
    negated and the two account ids swapped. The balance of every touched
    account is recalculated afterwards.

    Args:
        db: Async session; flushed here, not committed.
        user_id: Owner of the new transaction(s).
        data: Validated creation payload.
        base_currency: The user's base currency, stored on each row.

    Returns:
        The source-side transaction as a response dict.
    """
    now = datetime.now(timezone.utc)
    amount = data.amount
    exchange_rate = _fx_rate(data.currency, base_currency)

    txn = Transaction(
        user_id=user_id,
        account_id=data.account_id,
        transfer_account_id=data.transfer_account_id,
        category_id=data.category_id,
        type=data.type,
        status=data.status,
        amount=amount,
        amount_base=amount,  # Phase 3: convert via FX rate
        currency=data.currency,
        base_currency=base_currency,
        exchange_rate=exchange_rate,
        date=data.date,
        description_enc=encrypt_field(data.description),
        merchant_enc=_enc(data.merchant),
        notes_enc=_enc(data.notes),
        tags=data.tags,
        is_recurring=data.is_recurring,
        recurring_rule=data.recurring_rule,
        created_at=now,
        updated_at=now,
    )
    db.add(txn)
    await db.flush()

    # For transfers, create the mirrored counter-entry on the destination account.
    if data.type == "transfer" and data.transfer_account_id:
        counter = Transaction(
            user_id=user_id,
            account_id=data.transfer_account_id,
            transfer_account_id=data.account_id,
            category_id=data.category_id,
            type="transfer",
            status=data.status,
            amount=-amount,  # opposite sign
            amount_base=-amount,
            currency=data.currency,
            base_currency=base_currency,
            exchange_rate=exchange_rate,
            date=data.date,
            description_enc=encrypt_field(data.description),
            merchant_enc=_enc(data.merchant),
            notes_enc=_enc(data.notes),
            tags=data.tags,
            is_recurring=False,
            created_at=now,
            updated_at=now,
        )
        db.add(counter)
        await db.flush()
        await recalculate_balance(db, data.transfer_account_id)

    await recalculate_balance(db, data.account_id)
    return _to_response(txn)


def _matches_search(t: Transaction, term: str) -> bool:
    """Return True if lower-cased *term* occurs in the decrypted description or merchant."""
    description = _dec(t.description_enc) or ""
    if term in description.lower():
        return True
    merchant = _dec(t.merchant_enc)
    return bool(merchant) and term in merchant.lower()


async def list_transactions(
    db: AsyncSession,
    user_id: uuid.UUID,
    filters: TransactionFilter,
) -> dict:
    """List a user's non-deleted transactions, filtered and paginated.

    Results are ordered newest first (by ``date``, then ``created_at``).
    ``filters.search`` is a case-insensitive substring match on the
    description and merchant. Those columns are encrypted, so the match runs
    in Python after decryption.

    Returns:
        A dict with ``items`` (response dicts for the requested page),
        ``total`` (number of matching rows), ``page``, ``page_size`` and
        ``pages`` (at least 1).
    """
    conditions = [
        Transaction.user_id == user_id,
        Transaction.deleted_at.is_(None),
    ]
    if filters.account_id:
        conditions.append(Transaction.account_id == filters.account_id)
    if filters.category_id:
        conditions.append(Transaction.category_id == filters.category_id)
    if filters.type:
        conditions.append(Transaction.type == filters.type)
    if filters.status:
        conditions.append(Transaction.status == filters.status)
    if filters.date_from:
        conditions.append(Transaction.date >= filters.date_from)
    if filters.date_to:
        conditions.append(Transaction.date <= filters.date_to)
    if filters.min_amount is not None:
        conditions.append(Transaction.amount >= filters.min_amount)
    if filters.max_amount is not None:
        conditions.append(Transaction.amount <= filters.max_amount)
    if filters.is_recurring is not None:
        conditions.append(Transaction.is_recurring == filters.is_recurring)

    query = (
        select(Transaction)
        .where(and_(*conditions))
        .order_by(Transaction.date.desc(), Transaction.created_at.desc())
    )
    offset = (filters.page - 1) * filters.page_size

    if filters.search:
        # Encrypted text cannot be matched in SQL, so the search must be applied
        # to every candidate row BEFORE paginating. Otherwise `total`/`pages`
        # would count non-matching rows and pages would come back under-filled.
        # This loads all candidate rows; Phase 3 full-text search should replace it.
        term = filters.search.lower()
        result = await db.execute(query)
        matching = [t for t in result.scalars() if _matches_search(t, term)]
        total = len(matching)
        items = [_to_response(t) for t in matching[offset : offset + filters.page_size]]
    else:
        count_result = await db.execute(select(func.count()).select_from(query.subquery()))
        total = count_result.scalar_one()
        result = await db.execute(query.offset(offset).limit(filters.page_size))
        items = [_to_response(t) for t in result.scalars()]

    return {
        "items": items,
        "total": total,
        "page": filters.page,
        "page_size": filters.page_size,
        "pages": max(1, -(-total // filters.page_size)),  # ceiling division
    }


async def get_transaction(db: AsyncSession, txn_id: uuid.UUID, user_id: uuid.UUID) -> Transaction:
    """Fetch a non-deleted transaction owned by *user_id*.

    Raises:
        TransactionError: With status 404 if no such transaction exists.
    """
    result = await db.execute(
        select(Transaction).where(
            Transaction.id == txn_id,
            Transaction.user_id == user_id,
            Transaction.deleted_at.is_(None),
        )
    )
    txn = result.scalar_one_or_none()
    if not txn:
        raise TransactionError("Transaction not found", status_code=404)
    return txn


async def update_transaction(
    db: AsyncSession,
    txn_id: uuid.UUID,
    user_id: uuid.UUID,
    data: TransactionUpdate,
    base_currency: str,
) -> dict:
    """Apply the non-``None`` fields of *data* to a transaction and recalculate its account balance.

    ``base_currency`` is currently unused; it is kept for interface stability.
    Passing an empty string for ``merchant``/``notes`` clears that field.

    Raises:
        TransactionError: With status 404 if the transaction does not exist.
    """
    # NOTE(review): for transfers, the mirrored counter-entry written by
    # create_transaction is not updated here, so the destination account can
    # drift out of sync when amount/status/date change. Confirm intended.
    txn = await get_transaction(db, txn_id, user_id)
    now = datetime.now(timezone.utc)
    old_account_id = txn.account_id

    if data.category_id is not None:
        txn.category_id = data.category_id
    if data.status is not None:
        txn.status = data.status
    if data.amount is not None:
        txn.amount = data.amount
        txn.amount_base = data.amount  # Phase 3: convert via FX rate
    if data.date is not None:
        txn.date = data.date
    if data.description is not None:
        txn.description_enc = encrypt_field(data.description)
    if data.merchant is not None:
        txn.merchant_enc = _enc(data.merchant)
    if data.notes is not None:
        txn.notes_enc = _enc(data.notes)
    if data.tags is not None:
        txn.tags = data.tags

    txn.updated_at = now
    await db.flush()
    await recalculate_balance(db, old_account_id)
    return _to_response(txn)


async def delete_transaction(db: AsyncSession, txn_id: uuid.UUID, user_id: uuid.UUID) -> None:
    """Soft-delete a transaction (set ``deleted_at``) and recalculate its account balance.

    Raises:
        TransactionError: With status 404 if the transaction does not exist.
    """
    # NOTE(review): a transfer's counter-entry on the destination account is
    # not soft-deleted here. Confirm whether that is intended.
    txn = await get_transaction(db, txn_id, user_id)
    account_id = txn.account_id
    now = datetime.now(timezone.utc)
    txn.deleted_at = now
    txn.updated_at = now
    await db.flush()
    await recalculate_balance(db, account_id)


async def import_csv(
    db: AsyncSession,
    user_id: uuid.UUID,
    account_id: uuid.UUID,
    rows: list[dict],
    base_currency: str,
) -> dict:
    """
    Import transactions from parsed CSV rows.

    Each row must have: date, description, amount
    Optional: merchant, notes, currency (defaults to base_currency when
    missing or empty). Other columns (e.g. category_name) are currently ignored.

    Rows are deduplicated by a SHA-256 hash of the raw date|description|amount
    values, both against this user's stored transactions and within this batch.
    Positive amounts are imported as "income"; everything else as "expense".

    Returns a dict with ``imported`` and ``skipped`` counts. ``skipped``
    covers duplicates AND rows whose amount/date could not be parsed
    (or whose amount is not finite).
    """
    # Imported up front (not inside the per-row try) so a missing dependency
    # raises instead of silently skipping every row.
    import dateutil.parser

    imported = 0
    skipped = 0
    now = datetime.now(timezone.utc)
    seen_hashes: set[str] = set()

    for row in rows:
        # Build dedup hash from date + description + amount
        raw = f"{row['date']}|{row['description']}|{row['amount']}"
        import_hash = hashlib.sha256(raw.encode()).hexdigest()

        # Duplicates within this batch: tracked in memory so detection does
        # not depend on the session autoflushing pending rows before the query.
        if import_hash in seen_hashes:
            skipped += 1
            continue

        exists = await db.scalar(
            select(Transaction.id).where(
                Transaction.user_id == user_id,
                Transaction.import_hash == import_hash,
            )
        )
        if exists:
            skipped += 1
            continue

        try:
            amount = Decimal(str(row["amount"]))
            txn_date = dateutil.parser.parse(str(row["date"])).date()
        except (InvalidOperation, ValueError, OverflowError):
            skipped += 1
            continue

        # "NaN"/"Infinity" parse as Decimal but are not valid amounts, and
        # ordering comparisons on NaN (below) raise InvalidOperation.
        if not amount.is_finite():
            skipped += 1
            continue

        currency = row.get("currency") or base_currency
        txn = Transaction(
            user_id=user_id,
            account_id=account_id,
            type="income" if amount > 0 else "expense",
            status="cleared",
            amount=amount,
            amount_base=amount,  # Phase 3: convert via FX rate
            currency=currency,
            base_currency=base_currency,
            exchange_rate=_fx_rate(currency, base_currency),
            date=txn_date,
            description_enc=encrypt_field(str(row.get("description", ""))),
            merchant_enc=_enc(row.get("merchant")),
            notes_enc=_enc(row.get("notes")),
            tags=[],
            is_recurring=False,
            import_hash=import_hash,
            created_at=now,
            updated_at=now,
        )
        db.add(txn)
        seen_hashes.add(import_hash)
        imported += 1

    await db.flush()
    if imported > 0:
        await recalculate_balance(db, account_id)

    return {"imported": imported, "skipped": skipped}