Files
Metamorph/backend/app/services/auth.py

225 lines
6.7 KiB
Python
Raw Normal View History

2026-05-11 06:05:27 +02:00
"""Auth domain logic: login, refresh rotation, logout, change_password.
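Returns lightweight DTOs (e.g. the `TokenPair` dataclass); the API layer is
responsible for the HTTP response shape. Raises plain `ValueError` / `LookupError`
plus the `AuthError` subclasses defined here (`InvalidCredentials`, `TokenRevoked`)
and lets the API layer translate them into HTTP statuses.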
"""
from __future__ import annotations
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from sqlalchemy import select
from app.core.jwt_tokens import (
REFRESH_TOKEN_TTL,
decode_token,
encode_token,
generate_jti,
)
from app.core.security import (
hash_opaque_token,
hash_password,
needs_rehash,
verify_opaque_token,
verify_password,
)
from app.db.session import session_scope
from app.models.auth import RefreshToken, User
class AuthError(Exception):
    """Root of this module's auth-flow exception hierarchy.

    The HTTP layer is expected to map these to 401/403 responses.
    """


class InvalidCredentials(AuthError):
    """Password or token could not be accepted (wrong, unknown, malformed, expired)."""


class TokenRevoked(AuthError):
    """A refresh token that had already been revoked was presented again."""
@dataclass
class TokenPair:
    """Tokens handed back to the caller after a successful login or refresh.

    Only the refresh token is persisted server-side (as a hash); the access
    token is returned as-is and not stored.
    """

    access_token: str  # encoded access token
    refresh_token: str  # raw refresh token; only its hash is stored in the DB
    refresh_expires_at: datetime  # expiry taken from the refresh token's `exp` claim
    user_id: uuid.UUID  # owner of both tokens
def _now() -> datetime:
return datetime.now(tz=timezone.utc)
# === Login ===================================================================
# Lazily built hash used only to spend comparable time verifying a password when
# the email does not match any active user (see `login`). None until first use.
_DUMMY_PASSWORD_HASH = None


def _dummy_password_hash():
    """Return (building once) a well-formed hash to run throwaway verifications against.

    Its plaintext is irrelevant: the verification result is always discarded.
    Built on first use rather than at import to keep module import side-effect free.
    """
    global _DUMMY_PASSWORD_HASH
    if _DUMMY_PASSWORD_HASH is None:
        _DUMMY_PASSWORD_HASH = hash_password("timing-equalizer-not-a-real-password")
    return _DUMMY_PASSWORD_HASH


def login(email: str, password: str) -> TokenPair:
    """Authenticate by email + password and issue a fresh token pair.

    The email is matched after stripping surrounding whitespace and lowercasing.
    Only users with `deleted_at` unset and `is_active` true can log in. If the
    stored hash is flagged by `needs_rehash`, it is upgraded in the same
    transaction.

    Raises:
        InvalidCredentials: no matching active user, or wrong password. Both
            cases raise the same error so callers cannot tell them apart.
    """
    email_norm = email.strip().lower()
    with session_scope() as s:
        user = s.scalar(
            select(User).where(
                User.email == email_norm,
                User.deleted_at.is_(None),
                User.is_active.is_(True),
            )
        )
        if user is None:
            # Same error for "no such user" and "wrong password" (no account
            # enumeration). Also burn a verification so response *time* doesn't
            # leak which emails exist either. This assumes verify_password is
            # deliberately slow (argon2/bcrypt-style); confirm in app.core.security.
            # The result is intentionally ignored.
            verify_password(_dummy_password_hash(), password)
            raise InvalidCredentials("invalid credentials")
        if not verify_password(user.password_hash, password):
            raise InvalidCredentials("invalid credentials")
        if needs_rehash(user.password_hash):
            # Transparently upgrade hashes created with outdated parameters.
            user.password_hash = hash_password(password)
        return _issue_token_pair(s, user.id)
# === Refresh rotation ========================================================
def refresh(raw_refresh_token: str) -> TokenPair:
    """Validate a refresh token, revoke it, and mint a new access/refresh pair.

    Detects token reuse: if a refresh token that has already been revoked is
    presented again, it and the chain of tokens that replaced it are revoked
    (treated as a likely compromise), and `TokenRevoked` is raised.

    Raises:
        InvalidCredentials: the token does not decode as a refresh token, no
            stored row matches its jti + hash, or it has expired.
        TokenRevoked: the token had already been revoked (reuse detected).
    """
    try:
        claims = decode_token(raw_refresh_token, expected_type="refresh")
    except Exception as e:
        # decode_token's concrete failure types aren't visible here; any
        # failure means the token is unusable.
        raise InvalidCredentials("invalid refresh token") from e
    token_hash = hash_opaque_token(raw_refresh_token)
    new_pair: TokenPair | None = None
    with session_scope() as s:
        # NOTE(review): two concurrent refreshes with the same token can both
        # pass the revoked_at check below and each mint a pair; consider
        # `.with_for_update()` on this select if the DB supports row locks.
        rt = s.scalar(
            select(RefreshToken).where(
                RefreshToken.jti == claims.jti,
                RefreshToken.token_hash == token_hash,
            )
        )
        if rt is None:
            raise InvalidCredentials("refresh token not recognised")
        if rt.revoked_at is not None:
            # Reuse of a revoked token -> likely compromise: cascade-revoke the
            # chain. Deliberately NOT raising inside this block. If
            # session_scope rolls back on exception (the usual pattern), raising
            # here would discard the revocation we just made.
            _revoke_chain(s, rt)
        elif rt.expires_at <= _now():
            # NOTE(review): assumes expires_at is loaded tz-aware; a naive DB
            # value compared to _now() would raise TypeError. Confirm.
            raise InvalidCredentials("refresh token expired")
        else:
            # Rotate: mint the new pair, then mark the old row revoked and link
            # it to its replacement so _revoke_chain can follow it later.
            new_pair = _issue_token_pair(s, rt.user_id)
            new_jti = decode_token(new_pair.refresh_token, expected_type="refresh").jti
            new_rt = s.scalar(select(RefreshToken).where(RefreshToken.jti == new_jti))
            rt.revoked_at = _now()
            rt.replaced_by_id = new_rt.id if new_rt else None
    if new_pair is None:
        # Only reachable via the reuse branch; the session block has exited,
        # so the chain revocation is already persisted.
        raise TokenRevoked("refresh token has been revoked")
    return new_pair
# === Logout ==================================================================
def logout(raw_refresh_token: str) -> None:
    """Best-effort revocation of one refresh token.

    Tokens that fail to decode, are unknown, or are already revoked are
    silently ignored, so repeated calls are harmless.
    """
    try:
        decoded = decode_token(raw_refresh_token, expected_type="refresh")
    except Exception:
        # Deliberate: logout must never fail because of a garbage token.
        return
    with session_scope() as s:
        token_row = s.scalar(select(RefreshToken).where(RefreshToken.jti == decoded.jti))
        if token_row is None or token_row.revoked_at is not None:
            return
        token_row.revoked_at = _now()
def _revoke_active_refresh_tokens(s, user_id: uuid.UUID, when: datetime) -> int:
    """Set `revoked_at=when` on every unrevoked refresh token of `user_id`.

    Works inside the caller's session, so the revocation commits or rolls back
    together with whatever else that transaction does. Returns the row count.
    """
    active = s.scalars(
        select(RefreshToken).where(
            RefreshToken.user_id == user_id,
            RefreshToken.revoked_at.is_(None),
        )
    ).all()
    for rt in active:
        rt.revoked_at = when
    return len(active)


def logout_all_for_user(user_id: uuid.UUID) -> int:
    """Revoke every active refresh token for a user. Returns count revoked."""
    now = _now()
    with session_scope() as s:
        revoked = _revoke_active_refresh_tokens(s, user_id, now)
    return revoked


# === Password change ========================================================
def change_password(user_id: uuid.UUID, current: str, new: str) -> None:
    """Replace a user's password after re-verifying the current one.

    On success every refresh token the user holds is revoked, including the
    caller's own, forcing a fresh login everywhere. Access tokens are not
    persisted by this module, so already-issued ones stay valid until expiry.

    Raises:
        ValueError: `new` is shorter than 8 characters.
        LookupError: the user does not exist, has `deleted_at` set, or is inactive.
        InvalidCredentials: `current` does not match the stored hash.
    """
    if len(new) < 8:
        raise ValueError("new password must be at least 8 characters")
    with session_scope() as s:
        user = s.get(User, user_id)
        if user is None or user.deleted_at is not None or not user.is_active:
            raise LookupError("user not found")
        if not verify_password(user.password_hash, current):
            raise InvalidCredentials("current password is incorrect")
        user.password_hash = hash_password(new)
        # Revoke sessions in the SAME transaction as the hash update, so both
        # commit or neither does. Calling logout_all_for_user() would open a
        # second, independent session_scope; a failure there could leave the
        # new password saved while old refresh tokens remain usable.
        _revoke_active_refresh_tokens(s, user_id, _now())
# === Helpers =================================================================
def _issue_token_pair(s, user_id: uuid.UUID) -> TokenPair:
    """Mint an access token and a refresh token for `user_id`.

    Only the refresh token is recorded: a `RefreshToken` row holding its jti,
    hash (never the raw value), issue time and expiry. The session is flushed
    so that row has its primary key before this returns (refresh() looks it up).
    """
    access_jti, refresh_jti = generate_jti(), generate_jti()
    access_token, _access_claims = encode_token(user_id, "access", jti=access_jti)
    refresh_token, refresh_claims = encode_token(user_id, "refresh", jti=refresh_jti)

    record = RefreshToken(
        user_id=user_id,
        jti=refresh_jti,
        token_hash=hash_opaque_token(refresh_token),
        issued_at=refresh_claims.iat,
        expires_at=refresh_claims.exp,
    )
    s.add(record)
    s.flush()

    pair = TokenPair(
        access_token=access_token,
        refresh_token=refresh_token,
        refresh_expires_at=refresh_claims.exp,
        user_id=user_id,
    )
    return pair
def _revoke_chain(s, rt: RefreshToken) -> None:
    """Revoke `rt` and every token that succeeded it via `replaced_by_id`.

    Used when reuse of a revoked token is detected. Rows that are already
    revoked keep their existing timestamp. The `visited` set stops the walk if
    the replacement links ever form a cycle.
    """
    visited: set[uuid.UUID] = set()
    node: RefreshToken | None = rt
    while node is not None:
        if node.id in visited:
            break
        visited.add(node.id)
        if node.revoked_at is None:
            node.revoked_at = _now()
        successor_id = node.replaced_by_id
        node = s.get(RefreshToken, successor_id) if successor_id else None
__all__ = [
    # Exceptions
    "AuthError",
    "InvalidCredentials",
    "TokenRevoked",
    # DTOs and re-exported config
    "TokenPair",
    "REFRESH_TOKEN_TTL",
    # Operations
    "login",
    "refresh",
    "logout",
    "logout_all_for_user",
    "change_password",
]