"""JWT encoding / decoding. Two token types: - `access` — short-lived (1 h), in `Authorization: Bearer ...` headers, kept client-side **in memory** only (cf. spec §M2). - `refresh` — long-lived (30 d), in an HTTPOnly Secure SameSite=Strict cookie scoped to `/api/v1/auth/`. Rotated on every successful refresh, old `jti` revoked. We sign HS256 with `settings.JWT_SECRET`. The `jti` claim links each token to its DB row in `refresh_tokens` for revocation; access tokens are stateless. """ from __future__ import annotations import secrets import uuid from dataclasses import dataclass from datetime import datetime, timedelta, timezone from typing import Literal import jwt from app.core.config import settings ACCESS_TOKEN_TTL = timedelta(hours=1) REFRESH_TOKEN_TTL = timedelta(days=30) ALGORITHM = "HS256" ISSUER = "metamorph" TokenType = Literal["access", "refresh"] @dataclass(frozen=True) class TokenClaims: sub: str # user id (UUID as string) type: TokenType jti: str iat: datetime exp: datetime def _now() -> datetime: return datetime.now(tz=timezone.utc) def generate_jti() -> str: """Compact, URL-safe random identifier (≈22 chars).""" return secrets.token_urlsafe(16) def encode_token( user_id: uuid.UUID | str, token_type: TokenType, *, jti: str | None = None, ) -> tuple[str, TokenClaims]: """Return `(jwt_string, claims)`. `jti` is generated if not provided.""" now = _now() ttl = ACCESS_TOKEN_TTL if token_type == "access" else REFRESH_TOKEN_TTL claims = TokenClaims( sub=str(user_id), type=token_type, jti=jti or generate_jti(), iat=now, exp=now + ttl, ) payload = { "iss": ISSUER, "sub": claims.sub, "type": claims.type, "jti": claims.jti, "iat": int(claims.iat.timestamp()), "exp": int(claims.exp.timestamp()), } return jwt.encode(payload, settings.JWT_SECRET, algorithm=ALGORITHM), claims def decode_token(token: str, *, expected_type: TokenType) -> TokenClaims: """Decode and validate a JWT. Raises `jwt.PyJWTError` on any failure.""" payload = jwt.decode( token, settings.JWT_SECRET, algorithms=[ALGORITHM], issuer=ISSUER, options={"require": ["sub", "type", "jti", "iat", "exp"]}, ) if payload["type"] != expected_type: raise jwt.InvalidTokenError(f"expected {expected_type} token, got {payload['type']}") return TokenClaims( sub=payload["sub"], type=payload["type"], jti=payload["jti"], iat=datetime.fromtimestamp(payload["iat"], tz=timezone.utc), exp=datetime.fromtimestamp(payload["exp"], tz=timezone.utc), )