feature/m5-templates #2

Merged
knacky merged 6 commits from feature/m5-templates into main 2026-05-13 09:19:54 +00:00
7 changed files with 1600 additions and 0 deletions
Showing only changes of commit b8fd99a5f4 - Show all commits

View File

@@ -73,6 +73,17 @@ def reset_test_state():
"user_groups, settings, groups RESTART IDENTITY CASCADE" "user_groups, settings, groups RESTART IDENTITY CASCADE"
) )
) )
# Template catalogue reset (M5). The MITRE truncate below cascades to
# the polymorphic tag join, but the template rows themselves must be
# wiped first because `scenario_template_tests.test_template_id` is
# ON DELETE RESTRICT.
conn.execute(
text(
"TRUNCATE scenario_template_tests, scenario_templates, "
"test_template_mitre_tags, test_templates "
"RESTART IDENTITY CASCADE"
)
)
# MITRE reference reset — kept in sync with `settings` so a freshly # MITRE reference reset — kept in sync with `settings` so a freshly
# reset stack has `GET /mitre/status` and `GET /mitre/tactics` agree # reset stack has `GET /mitre/status` and `GET /mitre/tactics` agree
# ("no data, no last_sync"). The e2e suite re-syncs via /mitre/sync # ("no data, no last_sync"). The e2e suite re-syncs via /mitre/sync

View File

@@ -0,0 +1,208 @@
"""Scenario-template CRUD + reorder endpoints.
`PUT /<id>/tests` is the reorder/replace endpoint — it takes the full ordered
list and rewrites the join rows. There's no partial mutation API for the test
list: the wire contract is simpler and the client (drag-and-drop) already
holds the full ordering.
"""
from __future__ import annotations
import logging
import uuid
from typing import Any
from flask import Blueprint, jsonify, request
from pydantic import BaseModel, Field, ValidationError
from app.core.auth_decorators import require_auth, require_perm
from app.services import scenario_templates as svc
bp = Blueprint("scenario_templates", __name__, url_prefix="/scenario-templates")
log = logging.getLogger("metamorph.api.scenario_templates")
class CreateScenarioPayload(BaseModel):
name: str = Field(min_length=1, max_length=255)
description: str | None = Field(default=None, max_length=4000)
test_template_ids: list[uuid.UUID] = Field(default_factory=list, max_length=512)
model_config = {"extra": "forbid"}
class UpdateScenarioPayload(BaseModel):
name: str | None = Field(default=None, min_length=1, max_length=255)
description: str | None = Field(default=None, max_length=4000)
model_config = {"extra": "forbid"}
class SetTestsPayload(BaseModel):
test_template_ids: list[uuid.UUID] = Field(default_factory=list, max_length=512)
model_config = {"extra": "forbid"}
def _serialize(sc: svc.ScenarioTemplateView) -> dict[str, Any]:
return {
"id": str(sc.id),
"name": sc.name,
"description": sc.description,
"tests": [
{
"position": t.position,
"test_template_id": str(t.test_template_id),
"test_template_name": t.test_template_name,
"test_template_deleted": t.test_template_deleted,
}
for t in sc.tests
],
"tests_count": sc.tests_count,
"deleted_at": sc.deleted_at.isoformat() if sc.deleted_at else None,
"created_at": sc.created_at.isoformat(),
"updated_at": sc.updated_at.isoformat(),
}
def _parse_uuid_or_400(raw: str):
try:
return uuid.UUID(raw)
except ValueError:
return None
def _pagination_args() -> tuple[int, int] | tuple[None, tuple[int, str]]:
try:
limit = int(request.args.get("limit", "100"))
offset = int(request.args.get("offset", "0"))
except ValueError:
return None, (400, "invalid_pagination")
return max(1, min(limit, 500)), max(0, offset)
@bp.get("")
@require_auth
@require_perm("scenario_template.read")
def list_scenario_templates():
paging = _pagination_args()
if paging[0] is None:
return jsonify({"error": paging[1][1]}), paging[1][0]
limit, offset = paging
q = request.args.get("q") or None
include_deleted = request.args.get("include_deleted", "false").lower() == "true"
items, total = svc.list_scenario_templates(
q=q, include_deleted=include_deleted, limit=limit, offset=offset
)
return jsonify(
{
"items": [_serialize(it) for it in items],
"total": total,
"limit": limit,
"offset": offset,
}
)
@bp.get("/<scenario_id>")
@require_auth
@require_perm("scenario_template.read")
def get_scenario_template(scenario_id: str):
sid = _parse_uuid_or_400(scenario_id)
if sid is None:
return jsonify({"error": "invalid_id"}), 400
include_deleted = request.args.get("include_deleted", "false").lower() == "true"
try:
view = svc.get_scenario_template(sid, include_deleted=include_deleted)
except svc.ScenarioTemplateNotFound:
return jsonify({"error": "not_found"}), 404
return jsonify(_serialize(view))
@bp.post("")
@require_auth
@require_perm("scenario_template.create")
def create_scenario_template():
try:
payload = CreateScenarioPayload.model_validate(request.get_json(silent=True) or {})
except ValidationError as e:
return jsonify({"error": "invalid_request", "details": e.errors()}), 400
try:
view = svc.create_scenario_template(
name=payload.name,
description=payload.description,
test_template_ids=payload.test_template_ids,
)
except svc.UnknownTestTemplate as e:
return jsonify({"error": "unknown_test_template", "message": str(e)}), 400
except ValueError as e:
return jsonify({"error": "invalid_request", "message": str(e)}), 400
log.info(
"metamorph.scenario_template.created",
extra={"id": str(view.id), "tests": len(view.tests)},
)
return jsonify(_serialize(view)), 201
@bp.patch("/<scenario_id>")
@require_auth
@require_perm("scenario_template.update")
def update_scenario_template(scenario_id: str):
sid = _parse_uuid_or_400(scenario_id)
if sid is None:
return jsonify({"error": "invalid_id"}), 400
raw = request.get_json(silent=True) or {}
try:
payload = UpdateScenarioPayload.model_validate(raw)
except ValidationError as e:
return jsonify({"error": "invalid_request", "details": e.errors()}), 400
kwargs: dict[str, Any] = {}
if "name" in raw:
kwargs["name"] = payload.name
if "description" in raw:
kwargs["description"] = payload.description
try:
view = svc.update_scenario_template(sid, **kwargs)
except svc.ScenarioTemplateNotFound:
return jsonify({"error": "not_found"}), 404
except ValueError as e:
return jsonify({"error": "invalid_request", "message": str(e)}), 400
return jsonify(_serialize(view))
@bp.put("/<scenario_id>/tests")
@require_auth
@require_perm("scenario_template.update")
def set_scenario_tests(scenario_id: str):
sid = _parse_uuid_or_400(scenario_id)
if sid is None:
return jsonify({"error": "invalid_id"}), 400
try:
payload = SetTestsPayload.model_validate(request.get_json(silent=True) or {})
except ValidationError as e:
return jsonify({"error": "invalid_request", "details": e.errors()}), 400
try:
view = svc.set_scenario_tests(sid, payload.test_template_ids)
except svc.ScenarioTemplateNotFound:
return jsonify({"error": "not_found"}), 404
except svc.UnknownTestTemplate as e:
return jsonify({"error": "unknown_test_template", "message": str(e)}), 400
log.info(
"metamorph.scenario_template.tests_set",
extra={"id": str(sid), "tests": len(view.tests)},
)
return jsonify(_serialize(view))
@bp.delete("/<scenario_id>")
@require_auth
@require_perm("scenario_template.delete")
def soft_delete_scenario_template(scenario_id: str):
sid = _parse_uuid_or_400(scenario_id)
if sid is None:
return jsonify({"error": "invalid_id"}), 400
try:
svc.soft_delete_scenario_template(sid)
except svc.ScenarioTemplateNotFound:
return jsonify({"error": "not_found"}), 404
log.info("metamorph.scenario_template.soft_deleted", extra={"id": str(sid)})
return jsonify({"ok": True})

View File

@@ -0,0 +1,250 @@
"""Test-template CRUD endpoints.
Reads gated by `test_template.read`. Writes gated by `test_template.{create,
update,delete}`. Service layer handles all DB work; this module only validates
the wire payload and shapes the JSON response.
"""
from __future__ import annotations
import logging
import uuid
from typing import Any
from flask import Blueprint, jsonify, request
from pydantic import BaseModel, Field, ValidationError
from app.core.auth_decorators import require_auth, require_perm
from app.services import test_templates as svc
bp = Blueprint("test_templates", __name__, url_prefix="/test-templates")
log = logging.getLogger("metamorph.api.test_templates")
# === Payload schemas ==========================================================
class MitreTagIn(BaseModel):
kind: str = Field(min_length=1)
external_id: str = Field(min_length=1, max_length=16)
model_config = {"extra": "forbid"}
class CreateTestTemplatePayload(BaseModel):
name: str = Field(min_length=1, max_length=255)
description: str | None = Field(default=None, max_length=4000)
objective: str | None = Field(default=None, max_length=4000)
procedure_md: str | None = Field(default=None, max_length=32_000)
prerequisites_md: str | None = Field(default=None, max_length=32_000)
expected_result_red_md: str | None = Field(default=None, max_length=32_000)
expected_detection_blue_md: str | None = Field(default=None, max_length=32_000)
opsec_level: str = Field(default="medium")
tags: list[str] = Field(default_factory=list, max_length=64)
expected_iocs: list[str] = Field(default_factory=list, max_length=128)
mitre_tags: list[MitreTagIn] = Field(default_factory=list, max_length=64)
model_config = {"extra": "forbid"}
class UpdateTestTemplatePayload(BaseModel):
name: str | None = Field(default=None, min_length=1, max_length=255)
description: str | None = Field(default=None, max_length=4000)
objective: str | None = Field(default=None, max_length=4000)
procedure_md: str | None = Field(default=None, max_length=32_000)
prerequisites_md: str | None = Field(default=None, max_length=32_000)
expected_result_red_md: str | None = Field(default=None, max_length=32_000)
expected_detection_blue_md: str | None = Field(default=None, max_length=32_000)
opsec_level: str | None = None
tags: list[str] | None = Field(default=None, max_length=64)
expected_iocs: list[str] | None = Field(default=None, max_length=128)
mitre_tags: list[MitreTagIn] | None = Field(default=None, max_length=64)
model_config = {"extra": "forbid"}
# === Serializers ==============================================================
def _serialize(t: svc.TestTemplateView) -> dict[str, Any]:
return {
"id": str(t.id),
"name": t.name,
"description": t.description,
"objective": t.objective,
"procedure_md": t.procedure_md,
"prerequisites_md": t.prerequisites_md,
"expected_result_red_md": t.expected_result_red_md,
"expected_detection_blue_md": t.expected_detection_blue_md,
"opsec_level": t.opsec_level,
"tags": list(t.tags),
"expected_iocs": list(t.expected_iocs),
"mitre_tags": [
{"kind": tag.kind, "external_id": tag.external_id, "name": tag.name, "url": tag.url}
for tag in t.mitre_tags
],
"deleted_at": t.deleted_at.isoformat() if t.deleted_at else None,
"created_at": t.created_at.isoformat(),
"updated_at": t.updated_at.isoformat(),
}
def _parse_uuid_or_400(raw: str):
try:
return uuid.UUID(raw)
except ValueError:
return None
def _pagination_args() -> tuple[int, int] | tuple[None, tuple[int, str]]:
try:
limit = int(request.args.get("limit", "100"))
offset = int(request.args.get("offset", "0"))
except ValueError:
return None, (400, "invalid_pagination")
return max(1, min(limit, 500)), max(0, offset)
# === Endpoints ================================================================
@bp.get("")
@require_auth
@require_perm("test_template.read")
def list_test_templates():
paging = _pagination_args()
if paging[0] is None:
return jsonify({"error": paging[1][1]}), paging[1][0]
limit, offset = paging
q = request.args.get("q") or None
tactic = request.args.get("tactic") or None
technique = request.args.get("technique") or None
subtechnique = request.args.get("subtechnique") or None
opsec_level = request.args.get("opsec") or None
tag = request.args.get("tag") or None
include_deleted = request.args.get("include_deleted", "false").lower() == "true"
try:
items, total = svc.list_test_templates(
q=q,
tactic=tactic,
technique=technique,
subtechnique=subtechnique,
opsec_level=opsec_level,
tag=tag,
include_deleted=include_deleted,
limit=limit,
offset=offset,
)
except ValueError as e:
return jsonify({"error": "invalid_request", "message": str(e)}), 400
return jsonify(
{
"items": [_serialize(it) for it in items],
"total": total,
"limit": limit,
"offset": offset,
}
)
@bp.get("/<template_id>")
@require_auth
@require_perm("test_template.read")
def get_test_template(template_id: str):
tid = _parse_uuid_or_400(template_id)
if tid is None:
return jsonify({"error": "invalid_id"}), 400
include_deleted = request.args.get("include_deleted", "false").lower() == "true"
try:
view = svc.get_test_template(tid, include_deleted=include_deleted)
except svc.TestTemplateNotFound:
return jsonify({"error": "not_found"}), 404
return jsonify(_serialize(view))
@bp.post("")
@require_auth
@require_perm("test_template.create")
def create_test_template():
try:
payload = CreateTestTemplatePayload.model_validate(request.get_json(silent=True) or {})
except ValidationError as e:
return jsonify({"error": "invalid_request", "details": e.errors()}), 400
try:
view = svc.create_test_template(
name=payload.name,
description=payload.description,
objective=payload.objective,
procedure_md=payload.procedure_md,
prerequisites_md=payload.prerequisites_md,
expected_result_red_md=payload.expected_result_red_md,
expected_detection_blue_md=payload.expected_detection_blue_md,
opsec_level=payload.opsec_level,
tags=payload.tags,
expected_iocs=payload.expected_iocs,
mitre_tags=[svc.MitreTagRef(kind=t.kind, external_id=t.external_id) for t in payload.mitre_tags],
)
except svc.UnknownMitreTag as e:
return jsonify({"error": "unknown_mitre_tag", "message": str(e)}), 400
except ValueError as e:
return jsonify({"error": "invalid_request", "message": str(e)}), 400
log.info(
"metamorph.test_template.created",
extra={"id": str(view.id), "template_name": view.name},
)
return jsonify(_serialize(view)), 201
@bp.put("/<template_id>")
@require_auth
@require_perm("test_template.update")
def update_test_template(template_id: str):
tid = _parse_uuid_or_400(template_id)
if tid is None:
return jsonify({"error": "invalid_id"}), 400
raw = request.get_json(silent=True) or {}
try:
payload = UpdateTestTemplatePayload.model_validate(raw)
except ValidationError as e:
return jsonify({"error": "invalid_request", "details": e.errors()}), 400
# Only forward keys actually present in the body — model_validate leaves
# missing fields as None and we can't distinguish "explicitly null" from
# "omitted". The set of keys in `raw` is the wire-level intent.
kwargs: dict[str, Any] = {}
for field_name in (
"name", "description", "objective", "procedure_md", "prerequisites_md",
"expected_result_red_md", "expected_detection_blue_md",
"opsec_level", "tags", "expected_iocs",
):
if field_name in raw:
kwargs[field_name] = getattr(payload, field_name)
if "mitre_tags" in raw:
kwargs["mitre_tags"] = (
[svc.MitreTagRef(kind=t.kind, external_id=t.external_id) for t in (payload.mitre_tags or [])]
)
try:
view = svc.update_test_template(tid, **kwargs)
except svc.TestTemplateNotFound:
return jsonify({"error": "not_found"}), 404
except svc.UnknownMitreTag as e:
return jsonify({"error": "unknown_mitre_tag", "message": str(e)}), 400
except ValueError as e:
return jsonify({"error": "invalid_request", "message": str(e)}), 400
log.info("metamorph.test_template.updated", extra={"id": str(tid), "fields": sorted(kwargs.keys())})
return jsonify(_serialize(view))
@bp.delete("/<template_id>")
@require_auth
@require_perm("test_template.delete")
def soft_delete_test_template(template_id: str):
tid = _parse_uuid_or_400(template_id)
if tid is None:
return jsonify({"error": "invalid_id"}), 400
try:
svc.soft_delete_test_template(tid)
except svc.TestTemplateNotFound:
return jsonify({"error": "not_found"}), 404
log.info("metamorph.test_template.soft_deleted", extra={"id": str(tid)})
return jsonify({"ok": True})

View File

@@ -11,7 +11,9 @@ from app.api.health import bp as health_bp
from app.api.invitations import bp as invitations_bp from app.api.invitations import bp as invitations_bp
from app.api.mitre import bp as mitre_bp from app.api.mitre import bp as mitre_bp
from app.api.permissions import bp as permissions_bp from app.api.permissions import bp as permissions_bp
from app.api.scenario_templates import bp as scenario_templates_bp
from app.api.setup import bp as setup_bp from app.api.setup import bp as setup_bp
from app.api.test_templates import bp as test_templates_bp
from app.api.users import bp as users_bp from app.api.users import bp as users_bp
bp = Blueprint("v1", __name__, url_prefix="/api/v1") bp = Blueprint("v1", __name__, url_prefix="/api/v1")
@@ -24,3 +26,5 @@ bp.register_blueprint(users_bp)
bp.register_blueprint(groups_bp) bp.register_blueprint(groups_bp)
bp.register_blueprint(permissions_bp) bp.register_blueprint(permissions_bp)
bp.register_blueprint(mitre_bp) bp.register_blueprint(mitre_bp)
bp.register_blueprint(test_templates_bp)
bp.register_blueprint(scenario_templates_bp)

View File

@@ -0,0 +1,240 @@
"""CRUD service for `scenario_templates` + their ordered test list.
Re-ordering is implemented as **full delete + re-insert** of the
`scenario_template_tests` rows. The UNIQUE (scenario_template_id, position)
constraint makes any naive position-swap fail mid-transaction; wiping the set
then re-inserting at positions 0..N-1 keeps the operation atomic and obvious.
The same test_template may legitimately appear multiple times in a scenario
(chained operations), so we key on `(scenario_id, position)`, not
`(scenario_id, test_template_id)`.
"""
from __future__ import annotations
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any
from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session, selectinload
_UNSET: Any = object()
from app.db.session import session_scope
from app.models.template import (
ScenarioTemplate,
ScenarioTemplateTest,
TestTemplate,
)
class ScenarioTemplateNotFound(Exception):
pass
class UnknownTestTemplate(Exception):
"""Raised when a scenario references a non-existent or soft-deleted test."""
@dataclass(frozen=True)
class ScenarioTestView:
position: int
test_template_id: uuid.UUID
test_template_name: str
test_template_deleted: bool
@dataclass(frozen=True)
class ScenarioTemplateView:
id: uuid.UUID
name: str
description: str | None
tests: list[ScenarioTestView]
tests_count: int
deleted_at: datetime | None
created_at: datetime
updated_at: datetime
def _to_view(s: Session, sc: ScenarioTemplate) -> ScenarioTemplateView:
test_ids = [link.test_template_id for link in sc.tests]
name_by_id: dict[uuid.UUID, tuple[str, bool]] = {}
if test_ids:
rows = s.scalars(select(TestTemplate).where(TestTemplate.id.in_(test_ids))).all()
for row in rows:
name_by_id[row.id] = (row.name, row.deleted_at is not None)
tests = [
ScenarioTestView(
position=link.position,
test_template_id=link.test_template_id,
test_template_name=name_by_id.get(link.test_template_id, ("<missing>", True))[0],
test_template_deleted=name_by_id.get(link.test_template_id, ("<missing>", True))[1],
)
for link in sc.tests
]
return ScenarioTemplateView(
id=sc.id,
name=sc.name,
description=sc.description,
tests=tests,
tests_count=len(tests),
deleted_at=sc.deleted_at,
created_at=sc.created_at,
updated_at=sc.updated_at,
)
def _base_query():
return select(ScenarioTemplate).options(selectinload(ScenarioTemplate.tests))
def list_scenario_templates(
*,
q: str | None = None,
include_deleted: bool = False,
limit: int = 100,
offset: int = 0,
) -> tuple[list[ScenarioTemplateView], int]:
with session_scope() as s:
stmt = _base_query().order_by(ScenarioTemplate.name.asc())
count_stmt = select(func.count()).select_from(ScenarioTemplate)
if not include_deleted:
stmt = stmt.where(ScenarioTemplate.deleted_at.is_(None))
count_stmt = count_stmt.where(ScenarioTemplate.deleted_at.is_(None))
if q:
like = f"%{q.lower()}%"
cond = or_(
func.lower(ScenarioTemplate.name).like(like),
func.lower(ScenarioTemplate.description).like(like),
)
stmt = stmt.where(cond)
count_stmt = count_stmt.where(cond)
total = s.scalar(count_stmt) or 0
rows = s.scalars(stmt.limit(max(1, min(limit, 500))).offset(max(0, offset))).all()
return [_to_view(s, sc) for sc in rows], int(total)
def get_scenario_template(scenario_id: uuid.UUID, *, include_deleted: bool = False) -> ScenarioTemplateView:
with session_scope() as s:
sc = s.get(ScenarioTemplate, scenario_id)
if sc is None:
raise ScenarioTemplateNotFound()
if sc.deleted_at is not None and not include_deleted:
raise ScenarioTemplateNotFound()
return _to_view(s, sc)
def _validate_test_ids(s: Session, ids: list[uuid.UUID]) -> None:
"""Reject unknown or soft-deleted test_template ids before persisting."""
if not ids:
return
found = s.execute(
select(TestTemplate.id, TestTemplate.deleted_at).where(TestTemplate.id.in_(ids))
).all()
known = {row.id for row in found}
deleted = {row.id for row in found if row.deleted_at is not None}
missing = set(ids) - known
if missing:
raise UnknownTestTemplate(f"unknown test_template ids: {sorted(str(m) for m in missing)}")
if deleted:
raise UnknownTestTemplate(
f"cannot reference soft-deleted test_template ids: {sorted(str(d) for d in deleted)}"
)
def _opt_str(value: str | None) -> str | None:
if value is None:
return None
s = value.strip()
return s or None
def create_scenario_template(
*,
name: str,
description: str | None = None,
test_template_ids: list[uuid.UUID] | None = None,
) -> ScenarioTemplateView:
name_norm = (name or "").strip()
if not name_norm:
raise ValueError("name is required")
ids = list(test_template_ids or [])
with session_scope() as s:
_validate_test_ids(s, ids)
sc = ScenarioTemplate(
name=name_norm,
description=_opt_str(description),
)
s.add(sc)
s.flush()
for position, tid in enumerate(ids):
s.add(
ScenarioTemplateTest(
scenario_template_id=sc.id,
test_template_id=tid,
position=position,
)
)
s.flush()
s.refresh(sc)
return _to_view(s, sc)
def update_scenario_template(
scenario_id: uuid.UUID,
*,
name: str | None = None,
description: Any = _UNSET,
) -> ScenarioTemplateView:
with session_scope() as s:
sc = s.get(ScenarioTemplate, scenario_id)
if sc is None or sc.deleted_at is not None:
raise ScenarioTemplateNotFound()
if name is not None:
n = name.strip()
if not n:
raise ValueError("name cannot be empty")
sc.name = n
if description is not _UNSET:
sc.description = _opt_str(description)
s.flush()
s.refresh(sc)
return _to_view(s, sc)
def set_scenario_tests(
scenario_id: uuid.UUID,
test_template_ids: list[uuid.UUID],
) -> ScenarioTemplateView:
"""Replace the entire ordered test list. `position` becomes the index."""
with session_scope() as s:
sc = s.get(ScenarioTemplate, scenario_id)
if sc is None or sc.deleted_at is not None:
raise ScenarioTemplateNotFound()
_validate_test_ids(s, test_template_ids)
# Wipe then re-insert. The UNIQUE(position) constraint forbids a
# naive UPDATE-swap; full-replace keeps the op atomic + readable.
for link in list(sc.tests):
s.delete(link)
s.flush()
for position, tid in enumerate(test_template_ids):
s.add(
ScenarioTemplateTest(
scenario_template_id=sc.id,
test_template_id=tid,
position=position,
)
)
s.flush()
s.refresh(sc)
return _to_view(s, sc)
def soft_delete_scenario_template(scenario_id: uuid.UUID) -> None:
with session_scope() as s:
sc = s.get(ScenarioTemplate, scenario_id)
if sc is None or sc.deleted_at is not None:
raise ScenarioTemplateNotFound()
sc.deleted_at = datetime.now(tz=timezone.utc)

View File

@@ -0,0 +1,395 @@
"""CRUD service for `test_templates` + their MITRE tags.
The MITRE tag set is **fully replaced** on every update — partial mutation of
the join rows would force the API client to track tag UUIDs they never created.
The polymorphic join (one of `tactic_id` / `technique_id` / `subtechnique_id`
populated) is owned here: callers pass `(kind, external_id)` tuples and we
resolve them to the matching MITRE row.
"""
from __future__ import annotations
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Iterable
from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session, selectinload
_UNSET: Any = object()
from app.db.session import session_scope
from app.db.types import MITRE_KINDS, OPSEC_LEVELS
from app.models.mitre import MitreSubtechnique, MitreTactic, MitreTechnique
from app.models.template import TestTemplate, TestTemplateMitreTag
class TestTemplateNotFound(Exception):
pass
class UnknownMitreTag(Exception):
"""Raised when an (kind, external_id) tuple doesn't resolve to a known MITRE row."""
@dataclass(frozen=True)
class MitreTagRef:
"""Inbound MITRE tag reference. `external_id` is the ATT&CK identifier
(TA…/T…/T….…) — we resolve it server-side, the client never sees UUIDs.
"""
kind: str # "tactic" | "technique" | "subtechnique"
external_id: str
@dataclass(frozen=True)
class MitreTagView:
kind: str
external_id: str
name: str
url: str | None
@dataclass(frozen=True)
class TestTemplateView:
id: uuid.UUID
name: str
description: str | None
objective: str | None
procedure_md: str | None
prerequisites_md: str | None
expected_result_red_md: str | None
expected_detection_blue_md: str | None
opsec_level: str
tags: list[str]
expected_iocs: list[str]
mitre_tags: list[MitreTagView]
deleted_at: datetime | None
created_at: datetime
updated_at: datetime
def _validate_opsec(value: str) -> str:
if value not in OPSEC_LEVELS:
raise ValueError(f"opsec_level must be one of {OPSEC_LEVELS}")
return value
def _normalize_string_list(values: Iterable[str] | None) -> list[str]:
if not values:
return []
seen: set[str] = set()
out: list[str] = []
for raw in values:
if not isinstance(raw, str):
raise ValueError("list items must be strings")
v = raw.strip()
if not v or v in seen:
continue
seen.add(v)
out.append(v)
return out
def _resolve_mitre_refs(s: Session, refs: list[MitreTagRef]) -> list[TestTemplateMitreTag]:
"""Translate `(kind, external_id)` pairs into half-populated join rows.
Validates that:
- `kind` is one of the supported values
- each external_id resolves to an existing MITRE row
- the combination is unique inside the payload (de-duped silently — same
tag twice is a no-op, not an error)
"""
if not refs:
return []
# Dedupe input
deduped: dict[tuple[str, str], MitreTagRef] = {}
for ref in refs:
if ref.kind not in MITRE_KINDS:
raise ValueError(f"mitre tag kind must be one of {MITRE_KINDS}")
if not ref.external_id:
raise ValueError("mitre tag external_id is required")
deduped[(ref.kind, ref.external_id)] = ref
tactic_ids = {r.external_id for r in deduped.values() if r.kind == "tactic"}
technique_ids = {r.external_id for r in deduped.values() if r.kind == "technique"}
subtechnique_ids = {r.external_id for r in deduped.values() if r.kind == "subtechnique"}
tactic_map = {
t.external_id: t.id
for t in s.scalars(select(MitreTactic).where(MitreTactic.external_id.in_(tactic_ids))).all()
}
technique_map = {
t.external_id: t.id
for t in s.scalars(select(MitreTechnique).where(MitreTechnique.external_id.in_(technique_ids))).all()
}
subtechnique_map = {
sb.external_id: sb.id
for sb in s.scalars(
select(MitreSubtechnique).where(MitreSubtechnique.external_id.in_(subtechnique_ids))
).all()
}
rows: list[TestTemplateMitreTag] = []
missing: list[tuple[str, str]] = []
for ref in deduped.values():
if ref.kind == "tactic":
mid = tactic_map.get(ref.external_id)
if mid is None:
missing.append((ref.kind, ref.external_id))
continue
rows.append(TestTemplateMitreTag(mitre_kind="tactic", tactic_id=mid))
elif ref.kind == "technique":
mid = technique_map.get(ref.external_id)
if mid is None:
missing.append((ref.kind, ref.external_id))
continue
rows.append(TestTemplateMitreTag(mitre_kind="technique", technique_id=mid))
else:
mid = subtechnique_map.get(ref.external_id)
if mid is None:
missing.append((ref.kind, ref.external_id))
continue
rows.append(TestTemplateMitreTag(mitre_kind="subtechnique", subtechnique_id=mid))
if missing:
raise UnknownMitreTag(f"unknown MITRE tags: {sorted(missing)}")
return rows
def _to_view(s: Session, t: TestTemplate) -> TestTemplateView:
tag_views: list[MitreTagView] = []
for tag in t.mitre_tags:
if tag.mitre_kind == "tactic" and tag.tactic_id is not None:
row = s.get(MitreTactic, tag.tactic_id)
if row is not None:
tag_views.append(MitreTagView(kind="tactic", external_id=row.external_id, name=row.name, url=row.url))
elif tag.mitre_kind == "technique" and tag.technique_id is not None:
row = s.get(MitreTechnique, tag.technique_id)
if row is not None:
tag_views.append(MitreTagView(kind="technique", external_id=row.external_id, name=row.name, url=row.url))
elif tag.mitre_kind == "subtechnique" and tag.subtechnique_id is not None:
row = s.get(MitreSubtechnique, tag.subtechnique_id)
if row is not None:
tag_views.append(
MitreTagView(kind="subtechnique", external_id=row.external_id, name=row.name, url=row.url)
)
tag_views.sort(key=lambda v: (v.kind, v.external_id))
return TestTemplateView(
id=t.id,
name=t.name,
description=t.description,
objective=t.objective,
procedure_md=t.procedure_md,
prerequisites_md=t.prerequisites_md,
expected_result_red_md=t.expected_result_red_md,
expected_detection_blue_md=t.expected_detection_blue_md,
opsec_level=t.opsec_level,
tags=list(t.tags or []),
expected_iocs=list(t.expected_iocs or []),
mitre_tags=tag_views,
deleted_at=t.deleted_at,
created_at=t.created_at,
updated_at=t.updated_at,
)
def _base_query():
return select(TestTemplate).options(selectinload(TestTemplate.mitre_tags))
def list_test_templates(
*,
q: str | None = None,
tactic: str | None = None, # external_id like "TA0006"
technique: str | None = None,
subtechnique: str | None = None,
opsec_level: str | None = None,
tag: str | None = None,
include_deleted: bool = False,
limit: int = 100,
offset: int = 0,
) -> tuple[list[TestTemplateView], int]:
with session_scope() as s:
stmt = _base_query().order_by(TestTemplate.name.asc())
count_stmt = select(func.count()).select_from(TestTemplate)
if not include_deleted:
stmt = stmt.where(TestTemplate.deleted_at.is_(None))
count_stmt = count_stmt.where(TestTemplate.deleted_at.is_(None))
if q:
like = f"%{q.lower()}%"
cond = or_(
func.lower(TestTemplate.name).like(like),
func.lower(TestTemplate.description).like(like),
)
stmt = stmt.where(cond)
count_stmt = count_stmt.where(cond)
if opsec_level:
_validate_opsec(opsec_level)
stmt = stmt.where(TestTemplate.opsec_level == opsec_level)
count_stmt = count_stmt.where(TestTemplate.opsec_level == opsec_level)
if tag:
stmt = stmt.where(TestTemplate.tags.any(tag))
count_stmt = count_stmt.where(TestTemplate.tags.any(tag))
# MITRE facet: resolve external_id → uuid then filter via join subquery.
if tactic or technique or subtechnique:
tag_ids: list[uuid.UUID] = []
if tactic:
tac = s.scalar(select(MitreTactic).where(MitreTactic.external_id == tactic))
if tac is None:
return [], 0
tag_ids.append(tac.id)
if technique:
tech = s.scalar(select(MitreTechnique).where(MitreTechnique.external_id == technique))
if tech is None:
return [], 0
tag_ids.append(tech.id)
if subtechnique:
sub = s.scalar(select(MitreSubtechnique).where(MitreSubtechnique.external_id == subtechnique))
if sub is None:
return [], 0
tag_ids.append(sub.id)
sub_q = (
select(TestTemplateMitreTag.test_template_id)
.where(
or_(
TestTemplateMitreTag.tactic_id.in_(tag_ids),
TestTemplateMitreTag.technique_id.in_(tag_ids),
TestTemplateMitreTag.subtechnique_id.in_(tag_ids),
)
)
.distinct()
)
stmt = stmt.where(TestTemplate.id.in_(sub_q))
count_stmt = count_stmt.where(TestTemplate.id.in_(sub_q))
total = s.scalar(count_stmt) or 0
rows = s.scalars(stmt.limit(max(1, min(limit, 500))).offset(max(0, offset))).all()
return [_to_view(s, t) for t in rows], int(total)
def get_test_template(template_id: uuid.UUID, *, include_deleted: bool = False) -> TestTemplateView:
with session_scope() as s:
t = s.get(TestTemplate, template_id)
if t is None:
raise TestTemplateNotFound()
if t.deleted_at is not None and not include_deleted:
raise TestTemplateNotFound()
return _to_view(s, t)
def create_test_template(
*,
name: str,
description: str | None = None,
objective: str | None = None,
procedure_md: str | None = None,
prerequisites_md: str | None = None,
expected_result_red_md: str | None = None,
expected_detection_blue_md: str | None = None,
opsec_level: str = "medium",
tags: list[str] | None = None,
expected_iocs: list[str] | None = None,
mitre_tags: list[MitreTagRef] | None = None,
) -> TestTemplateView:
name_norm = (name or "").strip()
if not name_norm:
raise ValueError("name is required")
_validate_opsec(opsec_level)
norm_tags = _normalize_string_list(tags)
norm_iocs = _normalize_string_list(expected_iocs)
with session_scope() as s:
t = TestTemplate(
name=name_norm,
description=_opt_str(description),
objective=_opt_str(objective),
procedure_md=procedure_md or None,
prerequisites_md=prerequisites_md or None,
expected_result_red_md=expected_result_red_md or None,
expected_detection_blue_md=expected_detection_blue_md or None,
opsec_level=opsec_level,
tags=norm_tags,
expected_iocs=norm_iocs,
)
s.add(t)
s.flush()
if mitre_tags:
rows = _resolve_mitre_refs(s, mitre_tags)
for row in rows:
row.test_template_id = t.id
s.add(row)
s.flush()
s.refresh(t)
return _to_view(s, t)
def _opt_str(value: str | None) -> str | None:
if value is None:
return None
s = value.strip()
return s or None
def update_test_template(
template_id: uuid.UUID,
*,
name: str | None = None,
description: Any = _UNSET,
objective: Any = _UNSET,
procedure_md: Any = _UNSET,
prerequisites_md: Any = _UNSET,
expected_result_red_md: Any = _UNSET,
expected_detection_blue_md: Any = _UNSET,
opsec_level: str | None = None,
tags: Any = _UNSET,
expected_iocs: Any = _UNSET,
mitre_tags: Any = _UNSET,
) -> TestTemplateView:
with session_scope() as s:
t = s.get(TestTemplate, template_id)
if t is None or t.deleted_at is not None:
raise TestTemplateNotFound()
if name is not None:
n = name.strip()
if not n:
raise ValueError("name cannot be empty")
t.name = n
if description is not _UNSET:
t.description = _opt_str(description)
if objective is not _UNSET:
t.objective = _opt_str(objective)
if procedure_md is not _UNSET:
t.procedure_md = procedure_md or None
if prerequisites_md is not _UNSET:
t.prerequisites_md = prerequisites_md or None
if expected_result_red_md is not _UNSET:
t.expected_result_red_md = expected_result_red_md or None
if expected_detection_blue_md is not _UNSET:
t.expected_detection_blue_md = expected_detection_blue_md or None
if opsec_level is not None:
_validate_opsec(opsec_level)
t.opsec_level = opsec_level
if tags is not _UNSET:
t.tags = _normalize_string_list(tags)
if expected_iocs is not _UNSET:
t.expected_iocs = _normalize_string_list(expected_iocs)
if mitre_tags is not _UNSET:
for row in list(t.mitre_tags):
s.delete(row)
s.flush()
rows = _resolve_mitre_refs(s, list(mitre_tags or []))
for row in rows:
row.test_template_id = t.id
s.add(row)
s.flush()
s.refresh(t)
return _to_view(s, t)
def soft_delete_test_template(template_id: uuid.UUID) -> None:
with session_scope() as s:
t = s.get(TestTemplate, template_id)
if t is None or t.deleted_at is not None:
raise TestTemplateNotFound()
t.deleted_at = datetime.now(tz=timezone.utc)

View File

@@ -0,0 +1,492 @@
"""M5 — Template catalogue integration tests.
Covers `test_template` and `scenario_template` CRUD + ordering + perm gating.
Relies on a minimal MITRE seed (T1059 / TA0001 / T1059.001) so the polymorphic
tag join can be exercised end-to-end.
"""
from __future__ import annotations
import json
import secrets
import pytest
from sqlalchemy import text
from app.core.install_token import regenerate_install_token
from app.main import create_app
from app.services import mitre_seed as mitre_svc
def _truncate_all(engine):
with engine.begin() as conn:
conn.execute(
text(
"TRUNCATE users, refresh_tokens, invitations, invitation_groups, "
"user_groups, group_permissions, permissions, settings, groups, "
"scenario_template_tests, scenario_templates, "
"test_template_mitre_tags, test_templates, "
"mitre_subtechniques, mitre_technique_tactics, mitre_techniques, "
"mitre_tactics RESTART IDENTITY CASCADE"
)
)
# Same minimal bundle as in test_mitre.py — keeps tag resolution deterministic
# without re-pulling the full enterprise STIX bundle.
_MINIMAL_BUNDLE = {
"type": "bundle",
"id": "bundle--00000000-0000-0000-0000-000000000002",
"spec_version": "2.1",
"objects": [
{
"type": "x-mitre-tactic",
"id": "x-mitre-tactic--ta0001",
"name": "Initial Access",
"x_mitre_shortname": "initial-access",
"external_references": [
{"source_name": "mitre-attack", "external_id": "TA0001"}
],
},
{
"type": "x-mitre-tactic",
"id": "x-mitre-tactic--ta0002",
"name": "Execution",
"x_mitre_shortname": "execution",
"external_references": [
{"source_name": "mitre-attack", "external_id": "TA0002"}
],
},
{
"type": "attack-pattern",
"id": "attack-pattern--t1059",
"name": "Command and Scripting Interpreter",
"kill_chain_phases": [
{"kill_chain_name": "mitre-attack", "phase_name": "execution"}
],
"external_references": [
{"source_name": "mitre-attack", "external_id": "T1059"}
],
},
{
"type": "attack-pattern",
"id": "attack-pattern--t1059-001",
"name": "PowerShell",
"x_mitre_is_subtechnique": True,
"external_references": [
{"source_name": "mitre-attack", "external_id": "T1059.001"}
],
},
{
"type": "relationship",
"id": "relationship--rel1",
"relationship_type": "subtechnique-of",
"source_ref": "attack-pattern--t1059-001",
"target_ref": "attack-pattern--t1059",
},
],
}
@pytest.fixture(scope="module")
def app(db_engine_or_skip, tmp_path_factory):
_truncate_all(db_engine_or_skip)
bundle_path = tmp_path_factory.mktemp("m5") / "stix.json"
bundle_path.write_text(json.dumps(_MINIMAL_BUNDLE))
mitre_svc.seed_mitre(source=bundle_path, expected_sha256=None)
flask_app = create_app()
flask_app.config.update(TESTING=True)
return flask_app
@pytest.fixture()
def client(app):
return app.test_client()
def _unique_email(prefix: str) -> str:
return f"{prefix}-{secrets.token_hex(4)}@metamorph.local"
@pytest.fixture(scope="module")
def admin(app):
token = regenerate_install_token()
email = _unique_email("admin")
password = "AdminPass1234!"
with app.test_client() as c:
r = c.post(
"/api/v1/setup",
json={"install_token": token, "email": email, "password": password},
)
assert r.status_code == 201, r.get_data(as_text=True)
return {"email": email, "password": password}
def _login(client, email: str, password: str) -> str:
r = client.post("/api/v1/auth/login", json={"email": email, "password": password})
assert r.status_code == 200, r.get_data(as_text=True)
return r.get_json()["access_token"]
def _bearer(token: str) -> dict[str, str]:
return {"Authorization": f"Bearer {token}"}
@pytest.fixture()
def admin_token(client, admin) -> str:
return _login(client, admin["email"], admin["password"])
# === Reader fixture: an invited user with only `test_template.read` =========
def _bootstrap_user_without_perms(client, admin_token: str, prefix: str) -> tuple[str, str]:
email = _unique_email(prefix)
inv = client.post(
"/api/v1/invitations",
headers=_bearer(admin_token),
json={"email_hint": email},
)
token = inv.get_json()["token"]
password = "ReaderPass1234!"
client.post(
f"/api/v1/invitations/accept/{token}",
json={"email": email, "password": password},
)
return email, _login(client, email, password)
# === test_template CRUD =====================================================
def _make_test(client, admin_token: str, **overrides):
body = {
"name": overrides.pop("name", f"Test {secrets.token_hex(2)}"),
"description": overrides.pop("description", "auto"),
"objective": "do thing",
"procedure_md": "1. step",
"expected_result_red_md": "red expectation",
"expected_detection_blue_md": "blue expectation",
"opsec_level": overrides.pop("opsec_level", "medium"),
"tags": overrides.pop("tags", ["fast"]),
"expected_iocs": ["evil.exe"],
"mitre_tags": overrides.pop("mitre_tags", [{"kind": "technique", "external_id": "T1059"}]),
**overrides,
}
r = client.post("/api/v1/test-templates", headers=_bearer(admin_token), json=body)
assert r.status_code == 201, r.get_data(as_text=True)
return r.get_json()
def test_create_test_template_with_mitre_tags(client, admin_token):
body = _make_test(
client,
admin_token,
name="PowerShell exec",
mitre_tags=[
{"kind": "tactic", "external_id": "TA0002"},
{"kind": "technique", "external_id": "T1059"},
{"kind": "subtechnique", "external_id": "T1059.001"},
],
)
assert body["opsec_level"] == "medium"
kinds = sorted((t["kind"], t["external_id"]) for t in body["mitre_tags"])
assert kinds == [
("subtechnique", "T1059.001"),
("tactic", "TA0002"),
("technique", "T1059"),
]
def test_create_test_template_rejects_unknown_mitre(client, admin_token):
r = client.post(
"/api/v1/test-templates",
headers=_bearer(admin_token),
json={
"name": "Bad",
"mitre_tags": [{"kind": "technique", "external_id": "T9999"}],
},
)
assert r.status_code == 400
assert r.get_json()["error"] == "unknown_mitre_tag"
def test_create_test_template_rejects_bad_opsec(client, admin_token):
r = client.post(
"/api/v1/test-templates",
headers=_bearer(admin_token),
json={"name": "Bad", "opsec_level": "burner"},
)
assert r.status_code == 400
def test_list_test_templates_filter_by_tactic(client, admin_token):
_make_test(
client,
admin_token,
name="filterable-1",
mitre_tags=[{"kind": "tactic", "external_id": "TA0002"}],
)
r = client.get(
"/api/v1/test-templates?tactic=TA0002",
headers=_bearer(admin_token),
)
assert r.status_code == 200
body = r.get_json()
names = [it["name"] for it in body["items"]]
assert "filterable-1" in names
def test_list_test_templates_filter_by_opsec(client, admin_token):
_make_test(client, admin_token, name="high-opsec", opsec_level="high")
r = client.get(
"/api/v1/test-templates?opsec=high",
headers=_bearer(admin_token),
)
assert r.status_code == 200
names = [it["name"] for it in r.get_json()["items"]]
assert "high-opsec" in names
assert all(it["opsec_level"] == "high" for it in r.get_json()["items"])
def test_list_test_templates_filter_by_tag(client, admin_token):
_make_test(client, admin_token, name="tagged-fast", tags=["fast", "phish"])
r = client.get(
"/api/v1/test-templates?tag=phish",
headers=_bearer(admin_token),
)
assert r.status_code == 200
names = [it["name"] for it in r.get_json()["items"]]
assert "tagged-fast" in names
def test_list_test_templates_search_q(client, admin_token):
_make_test(client, admin_token, name="unique-token-azertyuiop")
r = client.get(
"/api/v1/test-templates?q=AZERTYUIOP", # case-insensitive
headers=_bearer(admin_token),
)
assert r.status_code == 200
names = [it["name"] for it in r.get_json()["items"]]
assert "unique-token-azertyuiop" in names
def test_update_test_template_replaces_mitre_tags(client, admin_token):
body = _make_test(
client,
admin_token,
name="to-update",
mitre_tags=[{"kind": "tactic", "external_id": "TA0001"}],
)
r = client.put(
f"/api/v1/test-templates/{body['id']}",
headers=_bearer(admin_token),
json={"mitre_tags": [{"kind": "technique", "external_id": "T1059"}]},
)
assert r.status_code == 200, r.get_data(as_text=True)
updated = r.get_json()
kinds = [(t["kind"], t["external_id"]) for t in updated["mitre_tags"]]
assert kinds == [("technique", "T1059")]
def test_update_test_template_partial_keeps_unset_fields(client, admin_token):
body = _make_test(
client,
admin_token,
name="partial-update",
opsec_level="low",
tags=["a", "b"],
)
r = client.put(
f"/api/v1/test-templates/{body['id']}",
headers=_bearer(admin_token),
json={"name": "renamed"},
)
assert r.status_code == 200
updated = r.get_json()
assert updated["name"] == "renamed"
assert updated["opsec_level"] == "low" # untouched
assert set(updated["tags"]) == {"a", "b"} # untouched
def test_soft_delete_then_list_hides_by_default(client, admin_token):
body = _make_test(client, admin_token, name="to-be-deleted")
r = client.delete(
f"/api/v1/test-templates/{body['id']}", headers=_bearer(admin_token)
)
assert r.status_code == 200
r2 = client.get("/api/v1/test-templates", headers=_bearer(admin_token))
names = [it["name"] for it in r2.get_json()["items"]]
assert "to-be-deleted" not in names
# And reappears with include_deleted=true
r3 = client.get(
"/api/v1/test-templates?include_deleted=true",
headers=_bearer(admin_token),
)
names3 = [it["name"] for it in r3.get_json()["items"]]
assert "to-be-deleted" in names3
def test_read_perm_required(client, admin_token):
"""A user without `test_template.read` gets 403."""
_, eve_token = _bootstrap_user_without_perms(client, admin_token, "eve-noperm")
r = client.get("/api/v1/test-templates", headers=_bearer(eve_token))
assert r.status_code == 403
def test_write_perm_required(client, admin_token):
"""A user with only `test_template.read` cannot create.
Bootstrap path: create a dedicated group via the admin API, bind only the
`test_template.read` perm, then invite a user pre-assigned to that group.
"""
# 1. Create the read-only group + bind the single perm.
grp = client.post(
"/api/v1/groups",
headers=_bearer(admin_token),
json={"name": f"tpl-reader-{secrets.token_hex(2)}"},
).get_json()
r_set = client.put(
f"/api/v1/groups/{grp['id']}/permissions",
headers=_bearer(admin_token),
json={"codes": ["test_template.read"]},
)
assert r_set.status_code == 200, r_set.get_data(as_text=True)
# 2. Invite a user already attached to that group.
email = _unique_email("alice-readonly")
password = "ReaderPass1234!"
inv = client.post(
"/api/v1/invitations",
headers=_bearer(admin_token),
json={"email_hint": email, "group_ids": [grp["id"]]},
).get_json()
client.post(
f"/api/v1/invitations/accept/{inv['token']}",
json={"email": email, "password": password},
)
token = _login(client, email, password)
r = client.get("/api/v1/test-templates", headers=_bearer(token))
assert r.status_code == 200, r.get_data(as_text=True)
r2 = client.post(
"/api/v1/test-templates", headers=_bearer(token), json={"name": "X"}
)
assert r2.status_code == 403
# === scenario_template CRUD =================================================
def test_create_scenario_with_ordered_tests(client, admin_token):
a = _make_test(client, admin_token, name="scn-a")
b = _make_test(client, admin_token, name="scn-b")
c = _make_test(client, admin_token, name="scn-c")
r = client.post(
"/api/v1/scenario-templates",
headers=_bearer(admin_token),
json={
"name": "phishing-flow",
"description": "click → exec → persist",
"test_template_ids": [a["id"], b["id"], c["id"]],
},
)
assert r.status_code == 201, r.get_data(as_text=True)
body = r.get_json()
assert body["tests_count"] == 3
assert [t["position"] for t in body["tests"]] == [0, 1, 2]
assert [t["test_template_name"] for t in body["tests"]] == ["scn-a", "scn-b", "scn-c"]
def test_reorder_scenario_tests(client, admin_token):
a = _make_test(client, admin_token, name="reord-a")
b = _make_test(client, admin_token, name="reord-b")
c = _make_test(client, admin_token, name="reord-c")
created = client.post(
"/api/v1/scenario-templates",
headers=_bearer(admin_token),
json={
"name": "reorder-me",
"test_template_ids": [a["id"], b["id"], c["id"]],
},
).get_json()
# Reverse order.
r = client.put(
f"/api/v1/scenario-templates/{created['id']}/tests",
headers=_bearer(admin_token),
json={"test_template_ids": [c["id"], b["id"], a["id"]]},
)
assert r.status_code == 200
after = r.get_json()
assert [t["test_template_name"] for t in after["tests"]] == ["reord-c", "reord-b", "reord-a"]
# Re-reading via GET yields the same order — confirms persistence.
fresh = client.get(
f"/api/v1/scenario-templates/{created['id']}", headers=_bearer(admin_token)
).get_json()
assert [t["test_template_name"] for t in fresh["tests"]] == ["reord-c", "reord-b", "reord-a"]
def test_scenario_rejects_unknown_test_id(client, admin_token):
r = client.post(
"/api/v1/scenario-templates",
headers=_bearer(admin_token),
json={
"name": "bad",
"test_template_ids": ["00000000-0000-0000-0000-000000000000"],
},
)
assert r.status_code == 400
assert r.get_json()["error"] == "unknown_test_template"
def test_scenario_rejects_soft_deleted_test_on_create(client, admin_token):
a = _make_test(client, admin_token, name="will-be-deleted")
client.delete(f"/api/v1/test-templates/{a['id']}", headers=_bearer(admin_token))
r = client.post(
"/api/v1/scenario-templates",
headers=_bearer(admin_token),
json={"name": "linked", "test_template_ids": [a["id"]]},
)
assert r.status_code == 400
assert r.get_json()["error"] == "unknown_test_template"
def test_scenario_surfaces_soft_deleted_test_after_link(client, admin_token):
"""Once linked, a test can be soft-deleted without breaking the scenario —
the join row stays and the API flags the test as deleted."""
a = _make_test(client, admin_token, name="linked-then-deleted")
sc = client.post(
"/api/v1/scenario-templates",
headers=_bearer(admin_token),
json={"name": "survives", "test_template_ids": [a["id"]]},
).get_json()
client.delete(f"/api/v1/test-templates/{a['id']}", headers=_bearer(admin_token))
fresh = client.get(
f"/api/v1/scenario-templates/{sc['id']}", headers=_bearer(admin_token)
).get_json()
assert fresh["tests"][0]["test_template_deleted"] is True
def test_scenario_soft_delete(client, admin_token):
sc = client.post(
"/api/v1/scenario-templates",
headers=_bearer(admin_token),
json={"name": "doomed-scn"},
).get_json()
r = client.delete(
f"/api/v1/scenario-templates/{sc['id']}", headers=_bearer(admin_token)
)
assert r.status_code == 200
names = [
it["name"]
for it in client.get(
"/api/v1/scenario-templates", headers=_bearer(admin_token)
).get_json()["items"]
]
assert "doomed-scn" not in names
def test_scenario_perm_required(client, admin_token):
_, eve_token = _bootstrap_user_without_perms(client, admin_token, "scn-eve")
r = client.get("/api/v1/scenario-templates", headers=_bearer(eve_token))
assert r.status_code == 403