98 lines
2.7 KiB
Python
98 lines
2.7 KiB
Python
|
|
"""JWT encoding / decoding.
|
||
|
|
|
||
|
|
Two token types:

- `access`  — short-lived (1 h), sent in `Authorization: Bearer ...` headers and
  kept client-side **in memory** only (cf. spec §M2).
- `refresh` — long-lived (30 d), stored in an HttpOnly Secure SameSite=Strict
  cookie scoped to `/api/v1/auth/`. Rotated on every successful refresh; the
  old `jti` is revoked.

Tokens are signed with HS256 using `settings.JWT_SECRET`. The `jti` claim links
each refresh token to its DB row in `refresh_tokens` for revocation; access
tokens are stateless.
|
||
|
|
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import secrets
|
||
|
|
import uuid
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from datetime import datetime, timedelta, timezone
|
||
|
|
from typing import Literal
|
||
|
|
|
||
|
|
import jwt
|
||
|
|
|
||
|
|
from app.core.config import settings
|
||
|
|
|
||
|
|
# Token lifetimes; `encode_token` picks one of these based on the token type.
ACCESS_TOKEN_TTL = timedelta(hours=1)  # access tokens
REFRESH_TOKEN_TTL = timedelta(days=30)  # refresh tokens

# Signing algorithm and `iss` claim value, shared by encoding and decoding.
ALGORITHM = "HS256"
ISSUER = "metamorph"

# The closed set of token kinds this module issues and accepts.
TokenType = Literal["access", "refresh"]
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(frozen=True)
class TokenClaims:
    """Immutable, typed view of the claims carried by one of our JWTs.

    `encode_token` builds an instance before signing, and `decode_token`
    rebuilds one from a verified payload.
    """

    # Subject: the user id as a string (a UUID rendered with `str`).
    sub: str
    # Kind of token: "access" or "refresh".
    type: TokenType
    # Token identifier.
    jti: str
    # Instant the token was issued.
    iat: datetime
    # Instant the token expires.
    exp: datetime
|
||
|
|
|
||
|
|
|
||
|
|
def _now() -> datetime:
|
||
|
|
return datetime.now(tz=timezone.utc)
|
||
|
|
|
||
|
|
|
||
|
|
def generate_jti() -> str:
    """Return a fresh random token identifier.

    Sixteen random bytes from `secrets`, rendered URL-safe, which comes out
    at roughly 22 characters.
    """
    entropy_bytes = 16
    return secrets.token_urlsafe(entropy_bytes)
|
||
|
|
|
||
|
|
|
||
|
|
def encode_token(
    user_id: uuid.UUID | str,
    token_type: TokenType,
    *,
    jti: str | None = None,
) -> tuple[str, TokenClaims]:
    """Sign a new JWT for `user_id` and return `(jwt_string, claims)`.

    Args:
        user_id: Subject of the token; stored in the `sub` claim as
            `str(user_id)`.
        token_type: "access" (lifetime `ACCESS_TOKEN_TTL`) or "refresh"
            (lifetime `REFRESH_TOKEN_TTL`).
        jti: Token identifier to embed. A falsy value (`None` or `""`)
            means one is generated with `generate_jti()`.

    Returns:
        The encoded token and the `TokenClaims` it carries. `iat` and `exp`
        are whole-second UTC instants, so the returned claims compare equal
        to what `decode_token` rebuilds from the same token.

    Raises:
        ValueError: If `token_type` is neither "access" nor "refresh".
    """
    # Reject unknown types explicitly: silently falling back to the refresh
    # TTL would mint a 30-day token carrying a bogus `type` claim.
    if token_type == "access":
        ttl = ACCESS_TOKEN_TTL
    elif token_type == "refresh":
        ttl = REFRESH_TOKEN_TTL
    else:
        raise ValueError(f"unknown token type: {token_type!r}")

    # `iat` / `exp` are serialized as integer seconds, so drop microseconds up
    # front; otherwise the returned claims differ from the signed payload.
    now = _now().replace(microsecond=0)

    claims = TokenClaims(
        sub=str(user_id),
        type=token_type,
        jti=jti or generate_jti(),
        iat=now,
        exp=now + ttl,
    )
    payload = {
        "iss": ISSUER,
        "sub": claims.sub,
        "type": claims.type,
        "jti": claims.jti,
        "iat": int(claims.iat.timestamp()),
        "exp": int(claims.exp.timestamp()),
    }
    return jwt.encode(payload, settings.JWT_SECRET, algorithm=ALGORITHM), claims
|
||
|
|
|
||
|
|
|
||
|
|
def decode_token(token: str, *, expected_type: TokenType) -> TokenClaims:
    """Verify `token` and return its claims as a `TokenClaims`.

    The token is decoded with `settings.JWT_SECRET`, restricted to
    `ALGORITHM`, checked against `ISSUER`, and must contain the `sub`,
    `type`, `jti`, `iat` and `exp` claims. Its `type` claim must also equal
    `expected_type`.

    Raises:
        jwt.PyJWTError: If `jwt.decode` rejects the token, or (as
            `jwt.InvalidTokenError`) if the token is of the wrong type.
    """
    required_claims = ["sub", "type", "jti", "iat", "exp"]
    decoded = jwt.decode(
        token,
        settings.JWT_SECRET,
        algorithms=[ALGORITHM],
        issuer=ISSUER,
        options={"require": required_claims},
    )

    # Guard clause: an access token must not be usable as a refresh token,
    # and vice versa.
    actual_type = decoded["type"]
    if actual_type != expected_type:
        raise jwt.InvalidTokenError(f"expected {expected_type} token, got {actual_type}")

    # Timestamps travel as epoch seconds; rebuild timezone-aware UTC datetimes.
    issued_at = datetime.fromtimestamp(decoded["iat"], tz=timezone.utc)
    expires_at = datetime.fromtimestamp(decoded["exp"], tz=timezone.utc)
    return TokenClaims(
        sub=decoded["sub"],
        type=actual_type,
        jti=decoded["jti"],
        iat=issued_at,
        exp=expires_at,
    )
|