"""
FastAPI dependency injection: DB session, Redis, current authenticated user.
"""

from __future__ import annotations

import uuid
from datetime import datetime, timezone
from typing import AsyncGenerator, Callable

from fastapi import Cookie, Depends, HTTPException, Request, status
from jose import JWTError
from redis.asyncio import Redis
from sqlalchemy import select, text
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.security import decode_token, hash_token
from app.db.models.session import Session
from app.db.models.user import User

# These are set by main.py lifespan
_session_factory: Callable[[], AsyncSession] | None = None
_redis_client: Redis | None = None


def set_session_factory(factory: Callable[[], AsyncSession]) -> None:
    """Register the factory that :func:`get_db` uses to open sessions."""
    global _session_factory
    _session_factory = factory


def get_session_factory() -> Callable[[], AsyncSession] | None:
    """Return the registered session factory, or ``None`` if none was set."""
    return _session_factory


def set_redis_client(client: Redis) -> None:
    """Register the Redis client that :func:`get_redis` hands out."""
    global _redis_client
    _redis_client = client


async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield one database session, closed when the dependency is torn down.

    Raises:
        RuntimeError: if :func:`set_session_factory` has not been called yet.
    """
    if _session_factory is None:
        # Without this guard the failure is an opaque
        # "'NoneType' object is not callable".
        raise RuntimeError(
            "Database session factory is not initialised; "
            "call set_session_factory() first"
        )
    async with _session_factory() as session:
        yield session


async def get_redis() -> Redis | None:
    """Return the registered Redis client.

    Returns ``None`` if :func:`set_redis_client` has not been called yet.
    """
    return _redis_client


def _extract_bearer(request: Request) -> str | None:
    """Return the token from an ``Authorization: Bearer <token>`` header.

    The scheme name is compared case-insensitively (RFC 7235 §2.1) and
    surrounding whitespace is ignored.

    Returns:
        The token, or ``None`` if the header is missing, uses another
        scheme, or carries an empty token.
    """
    header = request.headers.get("Authorization", "").strip()
    scheme, _, credentials = header.partition(" ")
    if scheme.lower() != "bearer":
        return None
    token = credentials.strip()
    return token or None


def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 error that carries the ``WWW-Authenticate`` challenge.

    RFC 6750 §3 requires the challenge on 401 responses from
    bearer-protected resources.
    """
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


async def get_current_user(
    request: Request,
    db: AsyncSession = Depends(get_db),
) -> User:
    """Resolve the authenticated, non-deleted user for the current request.

    Steps:
      1. Decode the bearer token as an access token.
      2. Require a matching, unrevoked, unexpired ``Session`` row.
      3. Load the user.
      4. Set the PostgreSQL RLS context for the rest of the transaction.

    Raises:
        HTTPException: 401 if the token is missing or invalid, the session
            is expired or revoked, or the user does not exist.
    """
    token = _extract_bearer(request)
    if not token:
        raise _unauthorized("Not authenticated")

    try:
        payload = decode_token(token, token_type="access")
        # str() makes a non-string "sub" claim fail with ValueError (-> 401)
        # instead of escaping as AttributeError (-> 500) inside uuid.UUID.
        user_id = uuid.UUID(str(payload["sub"]))
    except (JWTError, ValueError, KeyError, TypeError) as exc:
        raise _unauthorized("Invalid token") from exc

    # The token must also match a live server-side session, so revoking
    # a session takes effect before the JWT itself expires.
    session_result = await db.execute(
        select(Session).where(
            Session.user_id == user_id,
            Session.token_hash == hash_token(token),
            Session.revoked_at.is_(None),
            Session.expires_at > datetime.now(timezone.utc),
        )
    )
    if session_result.scalar_one_or_none() is None:
        raise _unauthorized("Session expired or revoked")

    user_result = await db.execute(
        select(User).where(User.id == user_id, User.deleted_at.is_(None))
    )
    user = user_result.scalar_one_or_none()
    if user is None:
        raise _unauthorized("User not found")

    # Set RLS context so PostgreSQL RLS policies apply.
    # set_config(name, value, is_local=true) is the bind-parameter
    # equivalent of SET LOCAL, so no value is interpolated into SQL text.
    # Like SET LOCAL, it only lasts until the current transaction ends.
    # NOTE(review): queries issued after a commit in the same request run
    # without this context -- confirm no route relies on that.
    await db.execute(
        text("SELECT set_config('app.current_user_id', :user_id, true)"),
        {"user_id": str(user_id)},
    )
    return user