""" Authentication service: register, login, TOTP, sessions, brute-force protection. """ from __future__ import annotations import base64 import uuid from datetime import datetime, timedelta, timezone from jose import JWTError from redis.asyncio import Redis from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from app.config import get_settings from app.core.security import ( create_access_token, create_refresh_token, decrypt_field, decode_token, encrypt_field, generate_backup_codes, generate_csrf_token, generate_totp_qr_png, generate_totp_secret, hash_password, hash_token, verify_password, verify_totp, ) from app.db.models.session import Session from app.db.models.user import User class AuthError(Exception): def __init__(self, detail: str, status_code: int = 401): self.detail = detail self.status_code = status_code async def _lockout_key(email: str) -> str: return f"lockout:{email}" async def _check_and_record_failure(redis: Redis, email: str, settings) -> None: key = await _lockout_key(email) attempts = await redis.incr(key) if attempts == 1: await redis.expire(key, settings.lockout_base_seconds) if attempts >= settings.max_login_attempts: lockout_seconds = settings.lockout_base_seconds * (2 ** (attempts - settings.max_login_attempts)) await redis.expire(key, min(lockout_seconds, 86400)) # cap at 24h async def _is_locked_out(redis: Redis, email: str) -> bool: key = await _lockout_key(email) val = await redis.get(key) if val is None: return False settings = get_settings() return int(val) >= settings.max_login_attempts async def register_user(db: AsyncSession, email: str, password: str, display_name: str) -> User: settings = get_settings() # Single-user: block registration if user already exists if not settings.allow_registration: count = await db.scalar(select(func.count()).select_from(User).where(User.deleted_at.is_(None))) if count and count > 0: raise AuthError("Registration is disabled", status_code=403) existing = await db.scalar(select(User).where(User.email == email)) if existing: raise AuthError("Email already registered", status_code=409) now = datetime.now(timezone.utc) user = User( email=email, password_hash=hash_password(password), display_name=display_name, base_currency=settings.base_currency, created_at=now, updated_at=now, ) db.add(user) await db.flush() from app.services.tax_service import seed_default_rates await seed_default_rates(db, user.id) return user async def authenticate_user( db: AsyncSession, redis: Redis, email: str, password: str, ip: str | None, user_agent: str | None, ) -> tuple[User, str, str] | tuple[User, None, None]: """ Returns (user, access_token, refresh_token) if no TOTP required, or (user, None, None) if TOTP challenge needed. Raises AuthError on failure. """ settings = get_settings() if await _is_locked_out(redis, email): raise AuthError("Account temporarily locked due to too many failed attempts", status_code=429) user = await db.scalar( select(User).where(User.email == email, User.deleted_at.is_(None)) ) if not user or not verify_password(password, user.password_hash): await _check_and_record_failure(redis, email, settings) raise AuthError("Invalid email or password") # Clear lockout on success await redis.delete(await _lockout_key(email)) if user.totp_enabled: return user, None, None # Caller creates challenge token tokens = await _create_session(db, user, ip, user_agent) return user, *tokens async def _create_session( db: AsyncSession, user: User, ip: str | None, user_agent: str | None, ) -> tuple[str, str]: settings = get_settings() access_token = create_access_token(str(user.id)) refresh_token = create_refresh_token(str(user.id)) now = datetime.now(timezone.utc) session = Session( user_id=user.id, token_hash=hash_token(access_token), refresh_token_hash=hash_token(refresh_token), ip_address=ip, user_agent=user_agent, last_active_at=now, expires_at=now + timedelta(days=settings.refresh_token_expire_days), created_at=now, ) db.add(session) await db.flush() # Update user login info user.last_login_at = now user.last_login_ip = ip user.updated_at = now return access_token, refresh_token async def complete_totp_login( db: AsyncSession, challenge_token: str, totp_code: str, ip: str | None, user_agent: str | None, ) -> tuple[str, str]: try: payload = decode_token(challenge_token, token_type="totp_challenge") user_id = uuid.UUID(payload["sub"]) except (JWTError, ValueError, KeyError): raise AuthError("Invalid or expired challenge token") user = await db.get(User, user_id) if not user or not user.totp_enabled or not user.totp_secret_enc: raise AuthError("Invalid challenge") secret = decrypt_field(bytes.fromhex(user.totp_secret_enc) if isinstance(user.totp_secret_enc, str) else user.totp_secret_enc) if not verify_totp(secret, totp_code): raise AuthError("Invalid TOTP code") return await _create_session(db, user, ip, user_agent) def create_totp_challenge_token(user_id: uuid.UUID) -> str: from app.core.security import create_access_token from datetime import timedelta from datetime import datetime, timezone from app.config import get_settings from jose import jwt from pathlib import Path settings = get_settings() now = datetime.now(timezone.utc) payload = { "sub": str(user_id), "iat": now, "exp": now + timedelta(minutes=5), "type": "totp_challenge", } private_key = Path(settings.jwt_private_key_file).read_text() return jwt.encode(payload, private_key, algorithm=settings.jwt_algorithm) async def setup_totp(user: User, db: AsyncSession) -> tuple[str, str, list[str]]: """Generate TOTP secret, QR code, and backup codes. Does not enable TOTP yet.""" secret = generate_totp_secret() qr_png = generate_totp_qr_png(secret, user.email) backup_codes = generate_backup_codes(8) return secret, base64.b64encode(qr_png).decode(), backup_codes async def enable_totp(user: User, db: AsyncSession, secret: str, code: str) -> None: if not verify_totp(secret, code): raise AuthError("Invalid TOTP code — setup failed", status_code=400) encrypted = encrypt_field(secret) user.totp_secret_enc = encrypted.hex() user.totp_enabled = True user.updated_at = datetime.now(timezone.utc) await db.flush() async def disable_totp(user: User, db: AsyncSession, password: str) -> None: if not verify_password(password, user.password_hash): raise AuthError("Incorrect password", status_code=400) user.totp_secret_enc = None user.totp_enabled = False user.totp_backup_codes_enc = None user.updated_at = datetime.now(timezone.utc) await db.flush() async def revoke_session(db: AsyncSession, session_id: uuid.UUID, user_id: uuid.UUID) -> None: session = await db.get(Session, session_id) if not session or session.user_id != user_id: raise AuthError("Session not found", status_code=404) session.revoked_at = datetime.now(timezone.utc) await db.flush() async def revoke_all_sessions(db: AsyncSession, user_id: uuid.UUID, except_token_hash: str | None = None) -> None: from sqlalchemy import update stmt = ( update(Session) .where(Session.user_id == user_id, Session.revoked_at.is_(None)) ) if except_token_hash: stmt = stmt.where(Session.token_hash != except_token_hash) stmt = stmt.values(revoked_at=datetime.now(timezone.utc)) await db.execute(stmt) async def get_sessions(db: AsyncSession, user_id: uuid.UUID) -> list[Session]: result = await db.execute( select(Session).where( Session.user_id == user_id, Session.revoked_at.is_(None), Session.expires_at > datetime.now(timezone.utc), ).order_by(Session.created_at.desc()) ) return list(result.scalars())