Milestone 3
This commit is contained in:
0
backend/app/services/__init__.py
Normal file
0
backend/app/services/__init__.py
Normal file
224
backend/app/services/auth.py
Normal file
224
backend/app/services/auth.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""Auth domain logic: login, refresh rotation, logout, change_password.
|
||||
|
||||
Returns lightweight DTOs (dicts) — the API layer is responsible for HTTP shape.
|
||||
Raises plain `ValueError` / `LookupError` / `PermissionError` and lets the API
|
||||
layer translate them into HTTP statuses.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.jwt_tokens import (
|
||||
REFRESH_TOKEN_TTL,
|
||||
decode_token,
|
||||
encode_token,
|
||||
generate_jti,
|
||||
)
|
||||
from app.core.security import (
|
||||
hash_opaque_token,
|
||||
hash_password,
|
||||
needs_rehash,
|
||||
verify_opaque_token,
|
||||
verify_password,
|
||||
)
|
||||
from app.db.session import session_scope
|
||||
from app.models.auth import RefreshToken, User
|
||||
|
||||
|
||||
class AuthError(Exception):
|
||||
"""Base for auth-flow exceptions; HTTP layer maps to 401/403."""
|
||||
|
||||
|
||||
class InvalidCredentials(AuthError):
|
||||
pass
|
||||
|
||||
|
||||
class TokenRevoked(AuthError):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenPair:
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
refresh_expires_at: datetime
|
||||
user_id: uuid.UUID
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(tz=timezone.utc)
|
||||
|
||||
|
||||
# === Login ===================================================================
|
||||
|
||||
|
||||
def login(email: str, password: str) -> TokenPair:
|
||||
email_norm = email.strip().lower()
|
||||
with session_scope() as s:
|
||||
user = s.scalar(
|
||||
select(User).where(
|
||||
User.email == email_norm,
|
||||
User.deleted_at.is_(None),
|
||||
User.is_active.is_(True),
|
||||
)
|
||||
)
|
||||
if user is None or not verify_password(user.password_hash, password):
|
||||
# Same error for "no such user" and "wrong password" — no account enumeration.
|
||||
raise InvalidCredentials("invalid credentials")
|
||||
|
||||
if needs_rehash(user.password_hash):
|
||||
user.password_hash = hash_password(password)
|
||||
|
||||
return _issue_token_pair(s, user.id)
|
||||
|
||||
|
||||
# === Refresh rotation ========================================================
|
||||
|
||||
|
||||
def refresh(raw_refresh_token: str) -> TokenPair:
|
||||
"""Validate the refresh token, revoke the old one, mint a new pair.
|
||||
|
||||
Detects token reuse: if a refresh token that has already been rotated is
|
||||
presented again, we revoke the entire chain (treat as compromise).
|
||||
"""
|
||||
try:
|
||||
claims = decode_token(raw_refresh_token, expected_type="refresh")
|
||||
except Exception as e:
|
||||
raise InvalidCredentials("invalid refresh token") from e
|
||||
|
||||
token_hash = hash_opaque_token(raw_refresh_token)
|
||||
|
||||
with session_scope() as s:
|
||||
rt = s.scalar(
|
||||
select(RefreshToken).where(
|
||||
RefreshToken.jti == claims.jti,
|
||||
RefreshToken.token_hash == token_hash,
|
||||
)
|
||||
)
|
||||
if rt is None:
|
||||
raise InvalidCredentials("refresh token not recognised")
|
||||
|
||||
if rt.revoked_at is not None:
|
||||
# Reuse of a revoked token → likely compromise. Cascade-revoke chain.
|
||||
_revoke_chain(s, rt)
|
||||
raise TokenRevoked("refresh token has been revoked")
|
||||
|
||||
if rt.expires_at <= _now():
|
||||
raise InvalidCredentials("refresh token expired")
|
||||
|
||||
# Rotate: mark old as revoked + replaced_by, mint new.
|
||||
new_pair = _issue_token_pair(s, rt.user_id)
|
||||
new_jti = decode_token(new_pair.refresh_token, expected_type="refresh").jti
|
||||
new_rt = s.scalar(select(RefreshToken).where(RefreshToken.jti == new_jti))
|
||||
rt.revoked_at = _now()
|
||||
rt.replaced_by_id = new_rt.id if new_rt else None
|
||||
return new_pair
|
||||
|
||||
|
||||
# === Logout ==================================================================
|
||||
|
||||
|
||||
def logout(raw_refresh_token: str) -> None:
|
||||
"""Revoke the refresh token. Idempotent — silently no-ops on bad tokens."""
|
||||
try:
|
||||
claims = decode_token(raw_refresh_token, expected_type="refresh")
|
||||
except Exception:
|
||||
return
|
||||
with session_scope() as s:
|
||||
rt = s.scalar(select(RefreshToken).where(RefreshToken.jti == claims.jti))
|
||||
if rt is not None and rt.revoked_at is None:
|
||||
rt.revoked_at = _now()
|
||||
|
||||
|
||||
def logout_all_for_user(user_id: uuid.UUID) -> int:
|
||||
"""Revoke every active refresh token for a user. Returns count revoked."""
|
||||
now = _now()
|
||||
with session_scope() as s:
|
||||
active = s.scalars(
|
||||
select(RefreshToken).where(
|
||||
RefreshToken.user_id == user_id,
|
||||
RefreshToken.revoked_at.is_(None),
|
||||
)
|
||||
).all()
|
||||
for rt in active:
|
||||
rt.revoked_at = now
|
||||
return len(active)
|
||||
|
||||
|
||||
# === Password change ========================================================
|
||||
|
||||
|
||||
def change_password(user_id: uuid.UUID, current: str, new: str) -> None:
|
||||
if len(new) < 8:
|
||||
raise ValueError("new password must be at least 8 characters")
|
||||
with session_scope() as s:
|
||||
user = s.get(User, user_id)
|
||||
if user is None or user.deleted_at is not None or not user.is_active:
|
||||
raise LookupError("user not found")
|
||||
if not verify_password(user.password_hash, current):
|
||||
raise InvalidCredentials("current password is incorrect")
|
||||
user.password_hash = hash_password(new)
|
||||
# Force re-login on every other device.
|
||||
logout_all_for_user(user_id)
|
||||
|
||||
|
||||
# === Helpers =================================================================
|
||||
|
||||
|
||||
def _issue_token_pair(s, user_id: uuid.UUID) -> TokenPair:
|
||||
"""Issue a fresh access + refresh pair. The refresh row is persisted."""
|
||||
access_jti = generate_jti()
|
||||
refresh_jti = generate_jti()
|
||||
access_token, _ = encode_token(user_id, "access", jti=access_jti)
|
||||
refresh_token, refresh_claims = encode_token(user_id, "refresh", jti=refresh_jti)
|
||||
|
||||
s.add(
|
||||
RefreshToken(
|
||||
user_id=user_id,
|
||||
jti=refresh_jti,
|
||||
token_hash=hash_opaque_token(refresh_token),
|
||||
issued_at=refresh_claims.iat,
|
||||
expires_at=refresh_claims.exp,
|
||||
)
|
||||
)
|
||||
s.flush() # ensure the row gets an id before we return
|
||||
|
||||
return TokenPair(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
refresh_expires_at=refresh_claims.exp,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
|
||||
def _revoke_chain(s, rt: RefreshToken) -> None:
|
||||
"""When reuse is detected, revoke this token and its replacement chain."""
|
||||
seen: set[uuid.UUID] = set()
|
||||
cur: RefreshToken | None = rt
|
||||
while cur is not None and cur.id not in seen:
|
||||
seen.add(cur.id)
|
||||
if cur.revoked_at is None:
|
||||
cur.revoked_at = _now()
|
||||
if cur.replaced_by_id:
|
||||
cur = s.get(RefreshToken, cur.replaced_by_id)
|
||||
else:
|
||||
cur = None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AuthError",
|
||||
"InvalidCredentials",
|
||||
"TokenRevoked",
|
||||
"TokenPair",
|
||||
"REFRESH_TOKEN_TTL",
|
||||
"login",
|
||||
"refresh",
|
||||
"logout",
|
||||
"logout_all_for_user",
|
||||
"change_password",
|
||||
]
|
||||
98
backend/app/services/bootstrap.py
Normal file
98
backend/app/services/bootstrap.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""Initial bootstrap : seed `admin` / `redteam` / `blueteam` system groups + first admin.
|
||||
|
||||
The detailed permission seeding lives in M3 (`mitre.sync` etc.); for M2 we only
|
||||
need an `admin` group that effectively grants full access. We model that as an
|
||||
absent permission set + a special `is_system` flag on the group, plus the
|
||||
`@require_perm` decorator that bypasses checks for any user belonging to a
|
||||
system `admin` group. M3 will fill in the atomic permissions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.install_token import (
|
||||
mark_install_token_consumed,
|
||||
verify_install_token,
|
||||
)
|
||||
from app.core.security import hash_password
|
||||
from app.db.session import session_scope
|
||||
from app.models.auth import Group, User, UserGroup
|
||||
|
||||
ADMIN_GROUP_NAME = "admin"
|
||||
REDTEAM_GROUP_NAME = "redteam"
|
||||
BLUETEAM_GROUP_NAME = "blueteam"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BootstrapResult:
|
||||
user_id: uuid.UUID
|
||||
admin_group_id: uuid.UUID
|
||||
|
||||
|
||||
class BootstrapError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def ensure_system_groups() -> dict[str, uuid.UUID]:
|
||||
"""Create the three system groups if missing. Idempotent."""
|
||||
out: dict[str, uuid.UUID] = {}
|
||||
with session_scope() as s:
|
||||
for name, desc in (
|
||||
(ADMIN_GROUP_NAME, "Platform administrators — full access."),
|
||||
(REDTEAM_GROUP_NAME, "Red team operators."),
|
||||
(BLUETEAM_GROUP_NAME, "Blue team operators."),
|
||||
):
|
||||
grp = s.scalar(select(Group).where(Group.name == name, Group.is_system.is_(True)))
|
||||
if grp is None:
|
||||
grp = Group(name=name, description=desc, is_system=True)
|
||||
s.add(grp)
|
||||
s.flush()
|
||||
out[name] = grp.id
|
||||
return out
|
||||
|
||||
|
||||
def bootstrap_admin(
|
||||
*, install_token: str, email: str, password: str, display_name: str | None = None
|
||||
) -> BootstrapResult:
|
||||
"""Consume the install token, create the first admin user, attach to admin group."""
|
||||
if not verify_install_token(install_token):
|
||||
raise BootstrapError("invalid or already-consumed install token")
|
||||
if len(password) < 8:
|
||||
raise ValueError("password must be at least 8 characters")
|
||||
|
||||
email_norm = email.strip().lower()
|
||||
|
||||
# Re-check users count under transaction to avoid races.
|
||||
with session_scope() as s:
|
||||
if s.scalar(select(User.id).limit(1)) is not None:
|
||||
raise BootstrapError("setup already done — at least one user exists")
|
||||
|
||||
groups = ensure_system_groups()
|
||||
|
||||
with session_scope() as s:
|
||||
user = User(
|
||||
email=email_norm,
|
||||
display_name=(display_name or "").strip() or None,
|
||||
password_hash=hash_password(password),
|
||||
)
|
||||
s.add(user)
|
||||
s.flush()
|
||||
s.add(UserGroup(user_id=user.id, group_id=groups[ADMIN_GROUP_NAME]))
|
||||
admin_id = groups[ADMIN_GROUP_NAME]
|
||||
user_id = user.id
|
||||
|
||||
mark_install_token_consumed()
|
||||
|
||||
# Re-seed the permission catalogue + system-group bindings. This is called
|
||||
# at boot too, but on a fresh DB after `/diag/reset` the groups were just
|
||||
# recreated above and have no permissions yet — seeding here keeps the
|
||||
# bootstrap path self-contained.
|
||||
from app.services.permissions_seed import seed_all # noqa: PLC0415 — avoid import cycle
|
||||
|
||||
seed_all()
|
||||
|
||||
return BootstrapResult(user_id=user_id, admin_group_id=admin_id)
|
||||
210
backend/app/services/groups.py
Normal file
210
backend/app/services/groups.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""Admin-side group management: CRUD + permission bindings.
|
||||
|
||||
System groups (`is_system=True`: admin, redteam, blueteam) cannot be renamed
|
||||
or deleted, but their permission bindings are seeded on boot and editable
|
||||
afterwards (e.g. an admin can broaden `redteam`).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from app.db.session import session_scope
|
||||
from app.models.auth import Group, GroupPermission, Permission, UserGroup
|
||||
from app.services.bootstrap import ADMIN_GROUP_NAME
|
||||
|
||||
|
||||
class GroupNotFound(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class GroupNameConflict(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SystemGroupProtected(Exception):
|
||||
"""Refusing to delete or rename a built-in system group."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GroupView:
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
description: str | None
|
||||
is_system: bool
|
||||
deleted_at: datetime | None
|
||||
members_count: int
|
||||
permissions: list[str]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
def _to_view(g: Group, members_count: int) -> GroupView:
|
||||
return GroupView(
|
||||
id=g.id,
|
||||
name=g.name,
|
||||
description=g.description,
|
||||
is_system=g.is_system,
|
||||
deleted_at=g.deleted_at,
|
||||
members_count=members_count,
|
||||
permissions=sorted(p.code for p in g.permissions),
|
||||
created_at=g.created_at,
|
||||
updated_at=g.updated_at,
|
||||
)
|
||||
|
||||
|
||||
def _members_counts(s, group_ids: list[uuid.UUID]) -> dict[uuid.UUID, int]:
|
||||
if not group_ids:
|
||||
return {}
|
||||
from app.models.auth import User as _U # local to avoid model cycles
|
||||
|
||||
rows = s.execute(
|
||||
select(UserGroup.group_id, func.count(UserGroup.user_id))
|
||||
.join(_U, _U.id == UserGroup.user_id)
|
||||
.where(UserGroup.group_id.in_(group_ids), _U.deleted_at.is_(None))
|
||||
.group_by(UserGroup.group_id)
|
||||
).all()
|
||||
return {gid: int(cnt) for gid, cnt in rows}
|
||||
|
||||
|
||||
def list_groups(*, include_deleted: bool = False) -> list[GroupView]:
|
||||
with session_scope() as s:
|
||||
stmt = select(Group).order_by(Group.is_system.desc(), Group.name.asc())
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(Group.deleted_at.is_(None))
|
||||
rows = s.scalars(stmt).all()
|
||||
counts = _members_counts(s, [g.id for g in rows])
|
||||
return [_to_view(g, counts.get(g.id, 0)) for g in rows]
|
||||
|
||||
|
||||
def get_group(group_id: uuid.UUID) -> GroupView:
|
||||
with session_scope() as s:
|
||||
g = s.get(Group, group_id)
|
||||
if g is None or g.deleted_at is not None:
|
||||
raise GroupNotFound()
|
||||
counts = _members_counts(s, [g.id])
|
||||
return _to_view(g, counts.get(g.id, 0))
|
||||
|
||||
|
||||
def create_group(*, name: str, description: str | None) -> GroupView:
|
||||
name_norm = name.strip()
|
||||
if not name_norm:
|
||||
raise ValueError("name is required")
|
||||
with session_scope() as s:
|
||||
existing = s.scalar(
|
||||
select(Group).where(Group.name == name_norm, Group.deleted_at.is_(None))
|
||||
)
|
||||
if existing is not None:
|
||||
raise GroupNameConflict(f"group name {name_norm!r} already in use")
|
||||
g = Group(name=name_norm, description=(description or "").strip() or None, is_system=False)
|
||||
s.add(g)
|
||||
s.flush()
|
||||
return _to_view(g, 0)
|
||||
|
||||
|
||||
def update_group(
|
||||
group_id: uuid.UUID,
|
||||
*,
|
||||
name: str | None = None,
|
||||
description: str | None | object = ...,
|
||||
) -> GroupView:
|
||||
with session_scope() as s:
|
||||
g = s.get(Group, group_id)
|
||||
if g is None or g.deleted_at is not None:
|
||||
raise GroupNotFound()
|
||||
if name is not None:
|
||||
name_norm = name.strip()
|
||||
if not name_norm:
|
||||
raise ValueError("name cannot be empty")
|
||||
if g.is_system and name_norm != g.name:
|
||||
raise SystemGroupProtected("system groups cannot be renamed")
|
||||
if name_norm != g.name:
|
||||
clash = s.scalar(
|
||||
select(Group).where(
|
||||
Group.name == name_norm,
|
||||
Group.deleted_at.is_(None),
|
||||
Group.id != g.id,
|
||||
)
|
||||
)
|
||||
if clash is not None:
|
||||
raise GroupNameConflict(f"group name {name_norm!r} already in use")
|
||||
g.name = name_norm
|
||||
if description is not ...:
|
||||
if description in (None, ""):
|
||||
g.description = None
|
||||
else:
|
||||
g.description = description.strip() or None
|
||||
|
||||
counts = _members_counts(s, [g.id])
|
||||
return _to_view(g, counts.get(g.id, 0))
|
||||
|
||||
|
||||
def soft_delete_group(group_id: uuid.UUID) -> None:
|
||||
with session_scope() as s:
|
||||
g = s.get(Group, group_id)
|
||||
if g is None or g.deleted_at is not None:
|
||||
raise GroupNotFound()
|
||||
if g.is_system:
|
||||
raise SystemGroupProtected("system groups cannot be deleted")
|
||||
g.deleted_at = datetime.now(tz=timezone.utc)
|
||||
|
||||
|
||||
def set_group_permissions(group_id: uuid.UUID, codes: list[str]) -> GroupView:
|
||||
"""Replace the group's permission set with the given codes (validated)."""
|
||||
desired_codes = set(codes)
|
||||
with session_scope() as s:
|
||||
g = s.get(Group, group_id)
|
||||
if g is None or g.deleted_at is not None:
|
||||
raise GroupNotFound()
|
||||
|
||||
# Preserve the invariant "the system `admin` group has every perm." The
|
||||
# decorator's admin bypass relies on `is_admin` (group membership), not
|
||||
# on the perm set, so a stripped admin group would still grant access —
|
||||
# but the listing would look misleading and a future refactor could
|
||||
# reasonably switch the bypass to a perm-based check.
|
||||
if g.is_system and g.name == ADMIN_GROUP_NAME:
|
||||
all_codes = {p.code for p in s.scalars(select(Permission)).all()}
|
||||
if desired_codes != all_codes:
|
||||
raise SystemGroupProtected(
|
||||
"the admin group must keep every permission"
|
||||
)
|
||||
|
||||
if desired_codes:
|
||||
perms = s.scalars(select(Permission).where(Permission.code.in_(desired_codes))).all()
|
||||
known = {p.code for p in perms}
|
||||
unknown = desired_codes - known
|
||||
if unknown:
|
||||
raise ValueError(f"unknown permission codes: {sorted(unknown)}")
|
||||
else:
|
||||
perms = []
|
||||
|
||||
current = {p.code: p for p in g.permissions}
|
||||
to_remove = set(current) - desired_codes
|
||||
to_add = desired_codes - set(current)
|
||||
|
||||
for code in to_remove:
|
||||
row = s.get(GroupPermission, (g.id, current[code].id))
|
||||
if row is not None:
|
||||
s.delete(row)
|
||||
for p in perms:
|
||||
if p.code in to_add:
|
||||
s.add(GroupPermission(group_id=g.id, permission_id=p.id))
|
||||
|
||||
s.flush()
|
||||
s.refresh(g)
|
||||
counts = _members_counts(s, [g.id])
|
||||
return _to_view(g, counts.get(g.id, 0))
|
||||
|
||||
|
||||
def list_permissions() -> list[dict]:
|
||||
"""Return the catalogue of all permissions known to the platform."""
|
||||
with session_scope() as s:
|
||||
rows = s.scalars(select(Permission).order_by(Permission.code.asc())).all()
|
||||
return [
|
||||
{"id": str(p.id), "code": p.code, "description": p.description}
|
||||
for p in rows
|
||||
]
|
||||
188
backend/app/services/invitations.py
Normal file
188
backend/app/services/invitations.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""Invitation flow: admin issues a one-shot URL token, invitee accepts.
|
||||
|
||||
The raw token is shown to the admin once (returned by `create_invitation`)
|
||||
and never persisted — only its SHA-256 lives in the DB. Pre-assigned groups
|
||||
are attached at creation and applied at acceptance.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Iterable
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.security import (
|
||||
generate_opaque_token,
|
||||
hash_opaque_token,
|
||||
hash_password,
|
||||
)
|
||||
from app.db.session import session_scope
|
||||
from app.models.auth import Group, Invitation, InvitationGroup, User, UserGroup
|
||||
|
||||
INVITATION_TTL = timedelta(days=7)
|
||||
|
||||
|
||||
class InvitationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class InvitationExpired(InvitationError):
|
||||
pass
|
||||
|
||||
|
||||
class InvitationConsumed(InvitationError):
|
||||
pass
|
||||
|
||||
|
||||
class InvitationRevoked(InvitationError):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class InvitationCreated:
|
||||
invitation_id: uuid.UUID
|
||||
raw_token: str
|
||||
expires_at: datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class InvitationPreview:
|
||||
email_hint: str | None
|
||||
expires_at: datetime
|
||||
groups: list[str]
|
||||
is_valid: bool
|
||||
reason: str | None
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(tz=timezone.utc)
|
||||
|
||||
|
||||
def create_invitation(
|
||||
*,
|
||||
created_by_user_id: uuid.UUID,
|
||||
email_hint: str | None,
|
||||
group_ids: Iterable[uuid.UUID] = (),
|
||||
ttl: timedelta = INVITATION_TTL,
|
||||
) -> InvitationCreated:
|
||||
raw = generate_opaque_token()
|
||||
expires_at = _now() + ttl
|
||||
with session_scope() as s:
|
||||
inv = Invitation(
|
||||
token_hash=hash_opaque_token(raw),
|
||||
email_hint=email_hint.strip().lower() if email_hint else None,
|
||||
created_by_user_id=created_by_user_id,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
s.add(inv)
|
||||
s.flush()
|
||||
for gid in group_ids:
|
||||
s.add(InvitationGroup(invitation_id=inv.id, group_id=gid))
|
||||
return InvitationCreated(
|
||||
invitation_id=inv.id,
|
||||
raw_token=raw,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
|
||||
def _load_by_token(s, raw_token: str) -> Invitation | None:
|
||||
return s.scalar(
|
||||
select(Invitation).where(Invitation.token_hash == hash_opaque_token(raw_token))
|
||||
)
|
||||
|
||||
|
||||
def preview(raw_token: str) -> InvitationPreview:
|
||||
with session_scope() as s:
|
||||
inv = _load_by_token(s, raw_token)
|
||||
if inv is None:
|
||||
return InvitationPreview(None, _now(), [], False, "not_found")
|
||||
groups = [g.name for g in inv.pre_assigned_groups]
|
||||
if inv.revoked_at is not None:
|
||||
return InvitationPreview(inv.email_hint, inv.expires_at, groups, False, "revoked")
|
||||
if inv.consumed_at is not None:
|
||||
return InvitationPreview(inv.email_hint, inv.expires_at, groups, False, "consumed")
|
||||
if inv.expires_at <= _now():
|
||||
return InvitationPreview(inv.email_hint, inv.expires_at, groups, False, "expired")
|
||||
return InvitationPreview(inv.email_hint, inv.expires_at, groups, True, None)
|
||||
|
||||
|
||||
def accept(raw_token: str, *, email: str, password: str, display_name: str | None) -> uuid.UUID:
|
||||
"""Create the user, attach pre-assigned groups, mark invitation consumed."""
|
||||
if len(password) < 8:
|
||||
raise ValueError("password must be at least 8 characters")
|
||||
|
||||
email_norm = email.strip().lower()
|
||||
with session_scope() as s:
|
||||
inv = _load_by_token(s, raw_token)
|
||||
if inv is None:
|
||||
raise InvitationError("invitation not found")
|
||||
if inv.revoked_at is not None:
|
||||
raise InvitationRevoked("invitation revoked")
|
||||
if inv.consumed_at is not None:
|
||||
raise InvitationConsumed("invitation already consumed")
|
||||
if inv.expires_at <= _now():
|
||||
raise InvitationExpired("invitation expired")
|
||||
|
||||
# Email must not be already in use among active users.
|
||||
existing = s.scalar(
|
||||
select(User).where(User.email == email_norm, User.deleted_at.is_(None))
|
||||
)
|
||||
if existing is not None:
|
||||
raise ValueError("email already in use")
|
||||
|
||||
user = User(
|
||||
email=email_norm,
|
||||
display_name=(display_name or "").strip() or None,
|
||||
password_hash=hash_password(password),
|
||||
)
|
||||
s.add(user)
|
||||
s.flush()
|
||||
|
||||
for grp in inv.pre_assigned_groups:
|
||||
s.add(UserGroup(user_id=user.id, group_id=grp.id))
|
||||
|
||||
inv.consumed_at = _now()
|
||||
inv.consumed_by_user_id = user.id
|
||||
return user.id
|
||||
|
||||
|
||||
def revoke(invitation_id: uuid.UUID) -> bool:
|
||||
with session_scope() as s:
|
||||
inv = s.get(Invitation, invitation_id)
|
||||
if inv is None:
|
||||
return False
|
||||
if inv.revoked_at is not None or inv.consumed_at is not None:
|
||||
return False
|
||||
inv.revoked_at = _now()
|
||||
return True
|
||||
|
||||
|
||||
def list_active(*, limit: int = 100) -> list[Invitation]:
|
||||
with session_scope() as s:
|
||||
rows = s.scalars(
|
||||
select(Invitation)
|
||||
.where(
|
||||
Invitation.consumed_at.is_(None),
|
||||
Invitation.revoked_at.is_(None),
|
||||
Invitation.expires_at > _now(),
|
||||
)
|
||||
.order_by(Invitation.created_at.desc())
|
||||
.limit(limit)
|
||||
).all()
|
||||
# detach so caller can read after session closes
|
||||
for r in rows:
|
||||
s.expunge(r)
|
||||
for g in r.pre_assigned_groups:
|
||||
s.expunge(g)
|
||||
return list(rows)
|
||||
|
||||
|
||||
def find_group_id_by_name(name: str) -> uuid.UUID | None:
|
||||
with session_scope() as s:
|
||||
gid = s.scalar(
|
||||
select(Group.id).where(Group.name == name, Group.deleted_at.is_(None))
|
||||
)
|
||||
return gid
|
||||
179
backend/app/services/permissions_seed.py
Normal file
179
backend/app/services/permissions_seed.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""Atomic permission catalogue + seed for the 3 default system groups.
|
||||
|
||||
Permissions follow the `<entity>.<action>` convention. They are the ground truth
|
||||
checked by `@require_perm`; admins bypass everything (cf. `auth_decorators.py`).
|
||||
|
||||
This module is the single place that lists every permission code shipped with
|
||||
the platform. To add a new perm in a future milestone:
|
||||
|
||||
1. Add an entry to `PERMISSION_CATALOGUE`.
|
||||
2. Decide which system group(s) should get it by default — edit
|
||||
`_default_redteam_perms()` / `_default_blueteam_perms()` if relevant
|
||||
(admin always gets everything, so no edit needed there).
|
||||
3. The next boot picks it up; existing groups are *upgraded* (perms added),
|
||||
never downgraded (we never remove perms from a system group, even if you
|
||||
trim the catalogue — that would be a destructive op disguised as a seed).
|
||||
|
||||
The seed is idempotent and safe to call on every boot.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db.session import session_scope
|
||||
from app.models.auth import Group, GroupPermission, Permission
|
||||
from app.services.bootstrap import (
|
||||
ADMIN_GROUP_NAME,
|
||||
BLUETEAM_GROUP_NAME,
|
||||
REDTEAM_GROUP_NAME,
|
||||
)
|
||||
|
||||
log = logging.getLogger("metamorph.permissions")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PermissionDef:
|
||||
code: str
|
||||
description: str
|
||||
|
||||
|
||||
# === Catalogue ================================================================
|
||||
#
|
||||
# Order is presentation-only; the seed is idempotent. Grouped by family to keep
|
||||
# diffs reviewable and to mirror the admin UI grouping in M3.6.
|
||||
#
|
||||
PERMISSION_CATALOGUE: tuple[PermissionDef, ...] = (
|
||||
# users
|
||||
PermissionDef("user.read", "View users."),
|
||||
PermissionDef("user.create", "Create users (typically via invitation)."),
|
||||
PermissionDef("user.update", "Update user metadata (display name, locale, active flag)."),
|
||||
PermissionDef("user.delete", "Soft-delete a user."),
|
||||
# groups
|
||||
PermissionDef("group.read", "View groups and their permissions."),
|
||||
PermissionDef("group.create", "Create a custom group."),
|
||||
PermissionDef("group.update", "Edit a custom group (name, description, permissions, members)."),
|
||||
PermissionDef("group.delete", "Soft-delete a custom group."),
|
||||
# invitations
|
||||
PermissionDef("invitation.read", "View pending invitations."),
|
||||
PermissionDef("invitation.create", "Issue a new invitation URL."),
|
||||
PermissionDef("invitation.revoke", "Revoke an unconsumed invitation."),
|
||||
# test templates
|
||||
PermissionDef("test_template.read", "View the test-template catalogue."),
|
||||
PermissionDef("test_template.create", "Create a test template."),
|
||||
PermissionDef("test_template.update", "Edit a test template."),
|
||||
PermissionDef("test_template.delete", "Soft-delete a test template."),
|
||||
# scenario templates
|
||||
PermissionDef("scenario_template.read", "View the scenario-template catalogue."),
|
||||
PermissionDef("scenario_template.create", "Create a scenario template."),
|
||||
PermissionDef("scenario_template.update", "Edit a scenario template (and its ordered tests)."),
|
||||
PermissionDef("scenario_template.delete", "Soft-delete a scenario template."),
|
||||
# missions
|
||||
PermissionDef("mission.read", "View missions (server still filters by membership for non-admin)."),
|
||||
PermissionDef("mission.create", "Create a mission."),
|
||||
PermissionDef("mission.update", "Edit mission metadata, scenarios, members."),
|
||||
PermissionDef("mission.archive", "Move a mission to status=archived."),
|
||||
PermissionDef("mission.delete", "Soft-delete a mission."),
|
||||
PermissionDef("mission.write_red_fields", "Write red-side fields on a mission test."),
|
||||
PermissionDef("mission.write_blue_fields", "Write blue-side fields and upload evidence."),
|
||||
# detection levels + platform settings + MITRE sync
|
||||
PermissionDef("detection_level.read", "View the detection-level taxonomy."),
|
||||
PermissionDef("detection_level.update", "Edit the detection-level taxonomy."),
|
||||
PermissionDef("setting.read", "Read platform settings."),
|
||||
PermissionDef("setting.update", "Update platform settings."),
|
||||
PermissionDef("mitre.sync", "Trigger a MITRE ATT&CK Enterprise re-sync."),
|
||||
)
|
||||
|
||||
|
||||
def _default_redteam_perms() -> frozenset[str]:
|
||||
return frozenset(
|
||||
{
|
||||
# catalogue read-only
|
||||
"test_template.read",
|
||||
"scenario_template.read",
|
||||
# MITRE/detection refs
|
||||
"detection_level.read",
|
||||
# missions: full lifecycle on red side
|
||||
"mission.read",
|
||||
"mission.create",
|
||||
"mission.update",
|
||||
"mission.archive",
|
||||
"mission.write_red_fields",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _default_blueteam_perms() -> frozenset[str]:
|
||||
return frozenset(
|
||||
{
|
||||
"test_template.read",
|
||||
"scenario_template.read",
|
||||
"detection_level.read",
|
||||
"mission.read",
|
||||
"mission.write_blue_fields",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _all_perm_codes() -> frozenset[str]:
|
||||
return frozenset(p.code for p in PERMISSION_CATALOGUE)
|
||||
|
||||
|
||||
def seed_permissions() -> dict[str, int]:
|
||||
"""Insert any missing permissions. Returns counts: `created`, `total`."""
|
||||
created = 0
|
||||
with session_scope() as s:
|
||||
existing_codes = set(s.scalars(select(Permission.code)).all())
|
||||
for p in PERMISSION_CATALOGUE:
|
||||
if p.code in existing_codes:
|
||||
continue
|
||||
s.add(Permission(code=p.code, description=p.description))
|
||||
created += 1
|
||||
return {"created": created, "total": len(PERMISSION_CATALOGUE)}
|
||||
|
||||
|
||||
def _assign_perms_to_group(group_name: str, codes: frozenset[str]) -> int:
|
||||
"""Attach the named perms to the given system group. Returns added count.
|
||||
|
||||
We never *remove* perms from a system group here — the seed is additive.
|
||||
Admins/operators who want to revoke must do so explicitly via the UI/API.
|
||||
"""
|
||||
if not codes:
|
||||
return 0
|
||||
added = 0
|
||||
with session_scope() as s:
|
||||
group = s.scalar(select(Group).where(Group.name == group_name, Group.is_system.is_(True)))
|
||||
if group is None:
|
||||
raise RuntimeError(f"system group {group_name!r} missing — call ensure_system_groups() first")
|
||||
|
||||
existing_codes = {p.code for p in group.permissions}
|
||||
perms = s.scalars(select(Permission).where(Permission.code.in_(codes))).all()
|
||||
for p in perms:
|
||||
if p.code in existing_codes:
|
||||
continue
|
||||
s.add(GroupPermission(group_id=group.id, permission_id=p.id))
|
||||
added += 1
|
||||
return added
|
||||
|
||||
|
||||
def seed_default_group_permissions() -> dict[str, int]:
|
||||
"""Bind the catalogue to the 3 default groups. Idempotent + additive."""
|
||||
counts: dict[str, int] = {}
|
||||
counts[ADMIN_GROUP_NAME] = _assign_perms_to_group(ADMIN_GROUP_NAME, _all_perm_codes())
|
||||
counts[REDTEAM_GROUP_NAME] = _assign_perms_to_group(REDTEAM_GROUP_NAME, _default_redteam_perms())
|
||||
counts[BLUETEAM_GROUP_NAME] = _assign_perms_to_group(BLUETEAM_GROUP_NAME, _default_blueteam_perms())
|
||||
return counts
|
||||
|
||||
|
||||
def seed_all() -> dict[str, dict[str, int]]:
|
||||
"""One-shot helper: catalogue + default group bindings."""
|
||||
perms = seed_permissions()
|
||||
bindings = seed_default_group_permissions()
|
||||
log.info(
|
||||
"metamorph.permissions.seeded",
|
||||
extra={"perms_created": perms["created"], "perms_total": perms["total"], "bindings": bindings},
|
||||
)
|
||||
return {"permissions": perms, "bindings": bindings}
|
||||
204
backend/app/services/users.py
Normal file
204
backend/app/services/users.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""Admin-side user management: list, get, update, soft-delete, assign groups.
|
||||
|
||||
Self-service updates (locale, password, display_name) live in
|
||||
`services.auth` — this module is for admin operations on other users.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Iterable
|
||||
|
||||
from sqlalchemy import func, or_, select
|
||||
|
||||
from app.db.session import session_scope
|
||||
from app.models.auth import Group, User, UserGroup
|
||||
from app.services.bootstrap import ADMIN_GROUP_NAME
|
||||
|
||||
|
||||
class UserNotFound(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class LastAdminProtected(Exception):
|
||||
"""Refusing to strip admin from the last active admin."""
|
||||
|
||||
|
||||
class SystemGroupProtected(Exception):
|
||||
"""Refusing to delete or rename a built-in system group."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UserView:
|
||||
id: uuid.UUID
|
||||
email: str
|
||||
display_name: str | None
|
||||
locale: str
|
||||
is_active: bool
|
||||
deleted_at: datetime | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
groups: list[tuple[uuid.UUID, str]]
|
||||
|
||||
|
||||
def _to_view(u: User) -> UserView:
|
||||
return UserView(
|
||||
id=u.id,
|
||||
email=u.email,
|
||||
display_name=u.display_name,
|
||||
locale=u.locale,
|
||||
is_active=u.is_active,
|
||||
deleted_at=u.deleted_at,
|
||||
created_at=u.created_at,
|
||||
updated_at=u.updated_at,
|
||||
groups=[(g.id, g.name) for g in u.groups if g.deleted_at is None],
|
||||
)
|
||||
|
||||
|
||||
def list_users(
|
||||
*,
|
||||
q: str | None = None,
|
||||
is_active: bool | None = None,
|
||||
include_deleted: bool = False,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[UserView], int]:
|
||||
"""Return (rows, total_count) with case-insensitive search on email + display_name."""
|
||||
with session_scope() as s:
|
||||
stmt = select(User)
|
||||
count_stmt = select(func.count()).select_from(User)
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(User.deleted_at.is_(None))
|
||||
count_stmt = count_stmt.where(User.deleted_at.is_(None))
|
||||
if is_active is not None:
|
||||
stmt = stmt.where(User.is_active.is_(is_active))
|
||||
count_stmt = count_stmt.where(User.is_active.is_(is_active))
|
||||
if q:
|
||||
like = f"%{q.lower()}%"
|
||||
stmt = stmt.where(
|
||||
or_(func.lower(User.email).like(like), func.lower(User.display_name).like(like))
|
||||
)
|
||||
count_stmt = count_stmt.where(
|
||||
or_(func.lower(User.email).like(like), func.lower(User.display_name).like(like))
|
||||
)
|
||||
stmt = stmt.order_by(User.email.asc()).limit(limit).offset(offset)
|
||||
rows = s.scalars(stmt).all()
|
||||
total = int(s.scalar(count_stmt) or 0)
|
||||
views = [_to_view(u) for u in rows]
|
||||
return views, total
|
||||
|
||||
|
||||
def get_user(user_id: uuid.UUID, *, include_deleted: bool = False) -> UserView:
|
||||
with session_scope() as s:
|
||||
u = s.get(User, user_id)
|
||||
if u is None or (u.deleted_at is not None and not include_deleted):
|
||||
raise UserNotFound()
|
||||
return _to_view(u)
|
||||
|
||||
|
||||
def update_user(
|
||||
user_id: uuid.UUID,
|
||||
*,
|
||||
display_name: str | None | object = ...,
|
||||
locale: str | None = None,
|
||||
is_active: bool | None = None,
|
||||
) -> UserView:
|
||||
"""Partial update. Pass display_name=None to clear; omit to leave unchanged."""
|
||||
with session_scope() as s:
|
||||
u = s.get(User, user_id)
|
||||
if u is None or u.deleted_at is not None:
|
||||
raise UserNotFound()
|
||||
if display_name is not ...:
|
||||
if display_name in (None, ""):
|
||||
u.display_name = None
|
||||
else:
|
||||
u.display_name = display_name.strip() or None
|
||||
if locale is not None:
|
||||
u.locale = locale
|
||||
if is_active is not None:
|
||||
# If deactivating the last active admin, refuse.
|
||||
if not is_active and _is_last_active_admin(s, u):
|
||||
raise LastAdminProtected("cannot deactivate the last active admin")
|
||||
u.is_active = is_active
|
||||
return _to_view(u)
|
||||
|
||||
|
||||
def soft_delete_user(user_id: uuid.UUID) -> None:
|
||||
with session_scope() as s:
|
||||
u = s.get(User, user_id)
|
||||
if u is None or u.deleted_at is not None:
|
||||
raise UserNotFound()
|
||||
if _is_last_active_admin(s, u):
|
||||
raise LastAdminProtected("cannot delete the last active admin")
|
||||
u.deleted_at = datetime.now(tz=timezone.utc)
|
||||
u.is_active = False
|
||||
|
||||
|
||||
def set_user_groups(user_id: uuid.UUID, group_ids: Iterable[uuid.UUID]) -> UserView:
|
||||
"""Replace the user's group memberships with the given set."""
|
||||
desired = set(group_ids)
|
||||
with session_scope() as s:
|
||||
u = s.get(User, user_id)
|
||||
if u is None or u.deleted_at is not None:
|
||||
raise UserNotFound()
|
||||
|
||||
# Resolve admin group id once.
|
||||
admin_group_id = s.scalar(
|
||||
select(Group.id).where(Group.name == ADMIN_GROUP_NAME, Group.is_system.is_(True))
|
||||
)
|
||||
is_currently_admin = admin_group_id in {g.id for g in u.groups}
|
||||
will_be_admin = admin_group_id in desired
|
||||
if is_currently_admin and not will_be_admin and _is_last_active_admin(s, u):
|
||||
raise LastAdminProtected("cannot remove admin from the last active admin")
|
||||
|
||||
# Refuse silently for unknown groups: validate first.
|
||||
if desired:
|
||||
known = set(
|
||||
s.scalars(
|
||||
select(Group.id).where(Group.id.in_(desired), Group.deleted_at.is_(None))
|
||||
).all()
|
||||
)
|
||||
unknown = desired - known
|
||||
if unknown:
|
||||
raise ValueError(f"unknown groups: {sorted(map(str, unknown))}")
|
||||
|
||||
current = {g.id for g in u.groups}
|
||||
to_add = desired - current
|
||||
to_remove = current - desired
|
||||
|
||||
for gid in to_remove:
|
||||
row = s.get(UserGroup, (u.id, gid))
|
||||
if row is not None:
|
||||
s.delete(row)
|
||||
for gid in to_add:
|
||||
s.add(UserGroup(user_id=u.id, group_id=gid))
|
||||
|
||||
s.flush()
|
||||
s.refresh(u)
|
||||
return _to_view(u)
|
||||
|
||||
|
||||
def _is_last_active_admin(s, user: User) -> bool:
|
||||
"""True when `user` is currently in the admin system group and removing/blocking
|
||||
them would leave the platform with zero active admins."""
|
||||
admin_group_id = s.scalar(
|
||||
select(Group.id).where(Group.name == ADMIN_GROUP_NAME, Group.is_system.is_(True))
|
||||
)
|
||||
if admin_group_id is None:
|
||||
return False
|
||||
if admin_group_id not in {g.id for g in user.groups}:
|
||||
return False
|
||||
other_admins = s.scalar(
|
||||
select(func.count())
|
||||
.select_from(User)
|
||||
.join(UserGroup, UserGroup.user_id == User.id)
|
||||
.where(
|
||||
UserGroup.group_id == admin_group_id,
|
||||
User.id != user.id,
|
||||
User.deleted_at.is_(None),
|
||||
User.is_active.is_(True),
|
||||
)
|
||||
)
|
||||
return int(other_admins or 0) == 0
|
||||
Reference in New Issue
Block a user