"""Auth domain logic: login, refresh rotation, logout, change_password. Returns lightweight DTOs (dicts) — the API layer is responsible for HTTP shape. Raises plain `ValueError` / `LookupError` / `PermissionError` and lets the API layer translate them into HTTP statuses. """ from __future__ import annotations import uuid from dataclasses import dataclass from datetime import datetime, timezone from sqlalchemy import select from app.core.jwt_tokens import ( REFRESH_TOKEN_TTL, decode_token, encode_token, generate_jti, ) from app.core.security import ( hash_opaque_token, hash_password, needs_rehash, verify_opaque_token, verify_password, ) from app.db.session import session_scope from app.models.auth import RefreshToken, User class AuthError(Exception): """Base for auth-flow exceptions; HTTP layer maps to 401/403.""" class InvalidCredentials(AuthError): pass class TokenRevoked(AuthError): pass @dataclass class TokenPair: access_token: str refresh_token: str refresh_expires_at: datetime user_id: uuid.UUID def _now() -> datetime: return datetime.now(tz=timezone.utc) # === Login =================================================================== def login(email: str, password: str) -> TokenPair: email_norm = email.strip().lower() with session_scope() as s: user = s.scalar( select(User).where( User.email == email_norm, User.deleted_at.is_(None), User.is_active.is_(True), ) ) if user is None or not verify_password(user.password_hash, password): # Same error for "no such user" and "wrong password" — no account enumeration. raise InvalidCredentials("invalid credentials") if needs_rehash(user.password_hash): user.password_hash = hash_password(password) return _issue_token_pair(s, user.id) # === Refresh rotation ======================================================== def refresh(raw_refresh_token: str) -> TokenPair: """Validate the refresh token, revoke the old one, mint a new pair. Detects token reuse: if a refresh token that has already been rotated is presented again, we revoke the entire chain (treat as compromise). """ try: claims = decode_token(raw_refresh_token, expected_type="refresh") except Exception as e: raise InvalidCredentials("invalid refresh token") from e token_hash = hash_opaque_token(raw_refresh_token) with session_scope() as s: rt = s.scalar( select(RefreshToken).where( RefreshToken.jti == claims.jti, RefreshToken.token_hash == token_hash, ) ) if rt is None: raise InvalidCredentials("refresh token not recognised") if rt.revoked_at is not None: # Reuse of a revoked token → likely compromise. Cascade-revoke chain. _revoke_chain(s, rt) raise TokenRevoked("refresh token has been revoked") if rt.expires_at <= _now(): raise InvalidCredentials("refresh token expired") # Rotate: mark old as revoked + replaced_by, mint new. new_pair = _issue_token_pair(s, rt.user_id) new_jti = decode_token(new_pair.refresh_token, expected_type="refresh").jti new_rt = s.scalar(select(RefreshToken).where(RefreshToken.jti == new_jti)) rt.revoked_at = _now() rt.replaced_by_id = new_rt.id if new_rt else None return new_pair # === Logout ================================================================== def logout(raw_refresh_token: str) -> None: """Revoke the refresh token. Idempotent — silently no-ops on bad tokens.""" try: claims = decode_token(raw_refresh_token, expected_type="refresh") except Exception: return with session_scope() as s: rt = s.scalar(select(RefreshToken).where(RefreshToken.jti == claims.jti)) if rt is not None and rt.revoked_at is None: rt.revoked_at = _now() def logout_all_for_user(user_id: uuid.UUID) -> int: """Revoke every active refresh token for a user. Returns count revoked.""" now = _now() with session_scope() as s: active = s.scalars( select(RefreshToken).where( RefreshToken.user_id == user_id, RefreshToken.revoked_at.is_(None), ) ).all() for rt in active: rt.revoked_at = now return len(active) # === Password change ======================================================== def change_password(user_id: uuid.UUID, current: str, new: str) -> None: if len(new) < 8: raise ValueError("new password must be at least 8 characters") with session_scope() as s: user = s.get(User, user_id) if user is None or user.deleted_at is not None or not user.is_active: raise LookupError("user not found") if not verify_password(user.password_hash, current): raise InvalidCredentials("current password is incorrect") user.password_hash = hash_password(new) # Force re-login on every other device. logout_all_for_user(user_id) # === Helpers ================================================================= def _issue_token_pair(s, user_id: uuid.UUID) -> TokenPair: """Issue a fresh access + refresh pair. The refresh row is persisted.""" access_jti = generate_jti() refresh_jti = generate_jti() access_token, _ = encode_token(user_id, "access", jti=access_jti) refresh_token, refresh_claims = encode_token(user_id, "refresh", jti=refresh_jti) s.add( RefreshToken( user_id=user_id, jti=refresh_jti, token_hash=hash_opaque_token(refresh_token), issued_at=refresh_claims.iat, expires_at=refresh_claims.exp, ) ) s.flush() # ensure the row gets an id before we return return TokenPair( access_token=access_token, refresh_token=refresh_token, refresh_expires_at=refresh_claims.exp, user_id=user_id, ) def _revoke_chain(s, rt: RefreshToken) -> None: """When reuse is detected, revoke this token and its replacement chain.""" seen: set[uuid.UUID] = set() cur: RefreshToken | None = rt while cur is not None and cur.id not in seen: seen.add(cur.id) if cur.revoked_at is None: cur.revoked_at = _now() if cur.replaced_by_id: cur = s.get(RefreshToken, cur.replaced_by_id) else: cur = None __all__ = [ "AuthError", "InvalidCredentials", "TokenRevoked", "TokenPair", "REFRESH_TOKEN_TTL", "login", "refresh", "logout", "logout_all_for_user", "change_password", ]