"""Invitation flow: admin issues a one-shot URL token, invitee accepts. The raw token is shown to the admin once (returned by `create_invitation`) and never persisted — only its SHA-256 lives in the DB. Pre-assigned groups are attached at creation and applied at acceptance. """ from __future__ import annotations import uuid from dataclasses import dataclass from datetime import datetime, timedelta, timezone from typing import Iterable from sqlalchemy import select from app.core.security import ( generate_opaque_token, hash_opaque_token, hash_password, ) from app.db.session import session_scope from app.models.auth import Group, Invitation, InvitationGroup, User, UserGroup INVITATION_TTL = timedelta(days=7) class InvitationError(Exception): pass class InvitationExpired(InvitationError): pass class InvitationConsumed(InvitationError): pass class InvitationRevoked(InvitationError): pass @dataclass class InvitationCreated: invitation_id: uuid.UUID raw_token: str expires_at: datetime @dataclass class InvitationPreview: email_hint: str | None expires_at: datetime groups: list[str] is_valid: bool reason: str | None def _now() -> datetime: return datetime.now(tz=timezone.utc) def create_invitation( *, created_by_user_id: uuid.UUID, email_hint: str | None, group_ids: Iterable[uuid.UUID] = (), ttl: timedelta = INVITATION_TTL, ) -> InvitationCreated: raw = generate_opaque_token() expires_at = _now() + ttl with session_scope() as s: inv = Invitation( token_hash=hash_opaque_token(raw), email_hint=email_hint.strip().lower() if email_hint else None, created_by_user_id=created_by_user_id, expires_at=expires_at, ) s.add(inv) s.flush() for gid in group_ids: s.add(InvitationGroup(invitation_id=inv.id, group_id=gid)) return InvitationCreated( invitation_id=inv.id, raw_token=raw, expires_at=expires_at, ) def _load_by_token(s, raw_token: str) -> Invitation | None: return s.scalar( select(Invitation).where(Invitation.token_hash == hash_opaque_token(raw_token)) ) def preview(raw_token: str) -> InvitationPreview: with session_scope() as s: inv = _load_by_token(s, raw_token) if inv is None: return InvitationPreview(None, _now(), [], False, "not_found") groups = [g.name for g in inv.pre_assigned_groups] if inv.revoked_at is not None: return InvitationPreview(inv.email_hint, inv.expires_at, groups, False, "revoked") if inv.consumed_at is not None: return InvitationPreview(inv.email_hint, inv.expires_at, groups, False, "consumed") if inv.expires_at <= _now(): return InvitationPreview(inv.email_hint, inv.expires_at, groups, False, "expired") return InvitationPreview(inv.email_hint, inv.expires_at, groups, True, None) def accept(raw_token: str, *, email: str, password: str, display_name: str | None) -> uuid.UUID: """Create the user, attach pre-assigned groups, mark invitation consumed.""" if len(password) < 8: raise ValueError("password must be at least 8 characters") email_norm = email.strip().lower() with session_scope() as s: inv = _load_by_token(s, raw_token) if inv is None: raise InvitationError("invitation not found") if inv.revoked_at is not None: raise InvitationRevoked("invitation revoked") if inv.consumed_at is not None: raise InvitationConsumed("invitation already consumed") if inv.expires_at <= _now(): raise InvitationExpired("invitation expired") # Email must not be already in use among active users. existing = s.scalar( select(User).where(User.email == email_norm, User.deleted_at.is_(None)) ) if existing is not None: raise ValueError("email already in use") user = User( email=email_norm, display_name=(display_name or "").strip() or None, password_hash=hash_password(password), ) s.add(user) s.flush() for grp in inv.pre_assigned_groups: s.add(UserGroup(user_id=user.id, group_id=grp.id)) inv.consumed_at = _now() inv.consumed_by_user_id = user.id return user.id def revoke(invitation_id: uuid.UUID) -> bool: with session_scope() as s: inv = s.get(Invitation, invitation_id) if inv is None: return False if inv.revoked_at is not None or inv.consumed_at is not None: return False inv.revoked_at = _now() return True def list_active(*, limit: int = 100) -> list[Invitation]: with session_scope() as s: rows = s.scalars( select(Invitation) .where( Invitation.consumed_at.is_(None), Invitation.revoked_at.is_(None), Invitation.expires_at > _now(), ) .order_by(Invitation.created_at.desc()) .limit(limit) ).all() # detach so caller can read after session closes for r in rows: s.expunge(r) for g in r.pre_assigned_groups: s.expunge(g) return list(rows) def find_group_id_by_name(name: str) -> uuid.UUID | None: with session_scope() as s: gid = s.scalar( select(Group.id).where(Group.name == name, Group.deleted_at.is_(None)) ) return gid