From b8fd99a5f4bcb991d75698c54d14ea47867b4465 Mon Sep 17 00:00:00 2001 From: Knacky Date: Tue, 12 May 2026 19:57:33 +0200 Subject: [PATCH 1/6] feat(m5): test_template + scenario_template CRUD with MITRE tags and ordered tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Service `app/services/test_templates.py`: CRUD with MITRE tag resolution (kind, external_id) → polymorphic join, filters by tactic/technique/ subtechnique/opsec/tag, `_UNSET` sentinel for partial-update semantics. - Service `app/services/scenario_templates.py`: ordered test list, reorder via full-replace (atomic w.r.t. UNIQUE(position) constraint), soft-delete. - REST endpoints on /api/v1/test-templates and /scenario-templates with pydantic schemas + perm gating (test_template.* and scenario_template.*). - /diag/reset truncates the 4 new tables before MITRE (FK ordering). - 19 pytest covering CRUD, MITRE tag merge, soft-delete chaining, perm enforcement, and reorder atomicity. Co-Authored-By: Claude Opus 4.7 (1M context) --- backend/app/api/diag.py | 11 + backend/app/api/scenario_templates.py | 208 +++++++++ backend/app/api/test_templates.py | 250 +++++++++++ backend/app/api/v1.py | 4 + backend/app/services/scenario_templates.py | 240 ++++++++++ backend/app/services/test_templates.py | 395 +++++++++++++++++ backend/tests/test_templates.py | 492 +++++++++++++++++++++ 7 files changed, 1600 insertions(+) create mode 100644 backend/app/api/scenario_templates.py create mode 100644 backend/app/api/test_templates.py create mode 100644 backend/app/services/scenario_templates.py create mode 100644 backend/app/services/test_templates.py create mode 100644 backend/tests/test_templates.py diff --git a/backend/app/api/diag.py b/backend/app/api/diag.py index 4fa3880..8dd0e20 100644 --- a/backend/app/api/diag.py +++ b/backend/app/api/diag.py @@ -73,6 +73,17 @@ def reset_test_state(): "user_groups, settings, groups RESTART IDENTITY CASCADE" ) ) + # Template catalogue reset (M5). The MITRE truncate below cascades to + # the polymorphic tag join, but the template rows themselves must be + # wiped first because `scenario_template_tests.test_template_id` is + # ON DELETE RESTRICT. + conn.execute( + text( + "TRUNCATE scenario_template_tests, scenario_templates, " + "test_template_mitre_tags, test_templates " + "RESTART IDENTITY CASCADE" + ) + ) # MITRE reference reset — kept in sync with `settings` so a freshly # reset stack has `GET /mitre/status` and `GET /mitre/tactics` agree # ("no data, no last_sync"). The e2e suite re-syncs via /mitre/sync diff --git a/backend/app/api/scenario_templates.py b/backend/app/api/scenario_templates.py new file mode 100644 index 0000000..4481f5d --- /dev/null +++ b/backend/app/api/scenario_templates.py @@ -0,0 +1,208 @@ +"""Scenario-template CRUD + reorder endpoints. + +`PUT //tests` is the reorder/replace endpoint — it takes the full ordered +list and rewrites the join rows. There's no partial mutation API for the test +list: the wire contract is simpler and the client (drag-and-drop) already +holds the full ordering. +""" + +from __future__ import annotations + +import logging +import uuid +from typing import Any + +from flask import Blueprint, jsonify, request +from pydantic import BaseModel, Field, ValidationError + +from app.core.auth_decorators import require_auth, require_perm +from app.services import scenario_templates as svc + +bp = Blueprint("scenario_templates", __name__, url_prefix="/scenario-templates") +log = logging.getLogger("metamorph.api.scenario_templates") + + +class CreateScenarioPayload(BaseModel): + name: str = Field(min_length=1, max_length=255) + description: str | None = Field(default=None, max_length=4000) + test_template_ids: list[uuid.UUID] = Field(default_factory=list, max_length=512) + + model_config = {"extra": "forbid"} + + +class UpdateScenarioPayload(BaseModel): + name: str | None = Field(default=None, min_length=1, max_length=255) + description: str | None = Field(default=None, max_length=4000) + + model_config = {"extra": "forbid"} + + +class SetTestsPayload(BaseModel): + test_template_ids: list[uuid.UUID] = Field(default_factory=list, max_length=512) + + model_config = {"extra": "forbid"} + + +def _serialize(sc: svc.ScenarioTemplateView) -> dict[str, Any]: + return { + "id": str(sc.id), + "name": sc.name, + "description": sc.description, + "tests": [ + { + "position": t.position, + "test_template_id": str(t.test_template_id), + "test_template_name": t.test_template_name, + "test_template_deleted": t.test_template_deleted, + } + for t in sc.tests + ], + "tests_count": sc.tests_count, + "deleted_at": sc.deleted_at.isoformat() if sc.deleted_at else None, + "created_at": sc.created_at.isoformat(), + "updated_at": sc.updated_at.isoformat(), + } + + +def _parse_uuid_or_400(raw: str): + try: + return uuid.UUID(raw) + except ValueError: + return None + + +def _pagination_args() -> tuple[int, int] | tuple[None, tuple[int, str]]: + try: + limit = int(request.args.get("limit", "100")) + offset = int(request.args.get("offset", "0")) + except ValueError: + return None, (400, "invalid_pagination") + return max(1, min(limit, 500)), max(0, offset) + + +@bp.get("") +@require_auth +@require_perm("scenario_template.read") +def list_scenario_templates(): + paging = _pagination_args() + if paging[0] is None: + return jsonify({"error": paging[1][1]}), paging[1][0] + limit, offset = paging + q = request.args.get("q") or None + include_deleted = request.args.get("include_deleted", "false").lower() == "true" + items, total = svc.list_scenario_templates( + q=q, include_deleted=include_deleted, limit=limit, offset=offset + ) + return jsonify( + { + "items": [_serialize(it) for it in items], + "total": total, + "limit": limit, + "offset": offset, + } + ) + + +@bp.get("/") +@require_auth +@require_perm("scenario_template.read") +def get_scenario_template(scenario_id: str): + sid = _parse_uuid_or_400(scenario_id) + if sid is None: + return jsonify({"error": "invalid_id"}), 400 + include_deleted = request.args.get("include_deleted", "false").lower() == "true" + try: + view = svc.get_scenario_template(sid, include_deleted=include_deleted) + except svc.ScenarioTemplateNotFound: + return jsonify({"error": "not_found"}), 404 + return jsonify(_serialize(view)) + + +@bp.post("") +@require_auth +@require_perm("scenario_template.create") +def create_scenario_template(): + try: + payload = CreateScenarioPayload.model_validate(request.get_json(silent=True) or {}) + except ValidationError as e: + return jsonify({"error": "invalid_request", "details": e.errors()}), 400 + try: + view = svc.create_scenario_template( + name=payload.name, + description=payload.description, + test_template_ids=payload.test_template_ids, + ) + except svc.UnknownTestTemplate as e: + return jsonify({"error": "unknown_test_template", "message": str(e)}), 400 + except ValueError as e: + return jsonify({"error": "invalid_request", "message": str(e)}), 400 + log.info( + "metamorph.scenario_template.created", + extra={"id": str(view.id), "tests": len(view.tests)}, + ) + return jsonify(_serialize(view)), 201 + + +@bp.patch("/") +@require_auth +@require_perm("scenario_template.update") +def update_scenario_template(scenario_id: str): + sid = _parse_uuid_or_400(scenario_id) + if sid is None: + return jsonify({"error": "invalid_id"}), 400 + raw = request.get_json(silent=True) or {} + try: + payload = UpdateScenarioPayload.model_validate(raw) + except ValidationError as e: + return jsonify({"error": "invalid_request", "details": e.errors()}), 400 + kwargs: dict[str, Any] = {} + if "name" in raw: + kwargs["name"] = payload.name + if "description" in raw: + kwargs["description"] = payload.description + try: + view = svc.update_scenario_template(sid, **kwargs) + except svc.ScenarioTemplateNotFound: + return jsonify({"error": "not_found"}), 404 + except ValueError as e: + return jsonify({"error": "invalid_request", "message": str(e)}), 400 + return jsonify(_serialize(view)) + + +@bp.put("//tests") +@require_auth +@require_perm("scenario_template.update") +def set_scenario_tests(scenario_id: str): + sid = _parse_uuid_or_400(scenario_id) + if sid is None: + return jsonify({"error": "invalid_id"}), 400 + try: + payload = SetTestsPayload.model_validate(request.get_json(silent=True) or {}) + except ValidationError as e: + return jsonify({"error": "invalid_request", "details": e.errors()}), 400 + try: + view = svc.set_scenario_tests(sid, payload.test_template_ids) + except svc.ScenarioTemplateNotFound: + return jsonify({"error": "not_found"}), 404 + except svc.UnknownTestTemplate as e: + return jsonify({"error": "unknown_test_template", "message": str(e)}), 400 + log.info( + "metamorph.scenario_template.tests_set", + extra={"id": str(sid), "tests": len(view.tests)}, + ) + return jsonify(_serialize(view)) + + +@bp.delete("/") +@require_auth +@require_perm("scenario_template.delete") +def soft_delete_scenario_template(scenario_id: str): + sid = _parse_uuid_or_400(scenario_id) + if sid is None: + return jsonify({"error": "invalid_id"}), 400 + try: + svc.soft_delete_scenario_template(sid) + except svc.ScenarioTemplateNotFound: + return jsonify({"error": "not_found"}), 404 + log.info("metamorph.scenario_template.soft_deleted", extra={"id": str(sid)}) + return jsonify({"ok": True}) diff --git a/backend/app/api/test_templates.py b/backend/app/api/test_templates.py new file mode 100644 index 0000000..86754f4 --- /dev/null +++ b/backend/app/api/test_templates.py @@ -0,0 +1,250 @@ +"""Test-template CRUD endpoints. + +Reads gated by `test_template.read`. Writes gated by `test_template.{create, +update,delete}`. Service layer handles all DB work; this module only validates +the wire payload and shapes the JSON response. +""" + +from __future__ import annotations + +import logging +import uuid +from typing import Any + +from flask import Blueprint, jsonify, request +from pydantic import BaseModel, Field, ValidationError + +from app.core.auth_decorators import require_auth, require_perm +from app.services import test_templates as svc + +bp = Blueprint("test_templates", __name__, url_prefix="/test-templates") +log = logging.getLogger("metamorph.api.test_templates") + + +# === Payload schemas ========================================================== + + +class MitreTagIn(BaseModel): + kind: str = Field(min_length=1) + external_id: str = Field(min_length=1, max_length=16) + + model_config = {"extra": "forbid"} + + +class CreateTestTemplatePayload(BaseModel): + name: str = Field(min_length=1, max_length=255) + description: str | None = Field(default=None, max_length=4000) + objective: str | None = Field(default=None, max_length=4000) + procedure_md: str | None = Field(default=None, max_length=32_000) + prerequisites_md: str | None = Field(default=None, max_length=32_000) + expected_result_red_md: str | None = Field(default=None, max_length=32_000) + expected_detection_blue_md: str | None = Field(default=None, max_length=32_000) + opsec_level: str = Field(default="medium") + tags: list[str] = Field(default_factory=list, max_length=64) + expected_iocs: list[str] = Field(default_factory=list, max_length=128) + mitre_tags: list[MitreTagIn] = Field(default_factory=list, max_length=64) + + model_config = {"extra": "forbid"} + + +class UpdateTestTemplatePayload(BaseModel): + name: str | None = Field(default=None, min_length=1, max_length=255) + description: str | None = Field(default=None, max_length=4000) + objective: str | None = Field(default=None, max_length=4000) + procedure_md: str | None = Field(default=None, max_length=32_000) + prerequisites_md: str | None = Field(default=None, max_length=32_000) + expected_result_red_md: str | None = Field(default=None, max_length=32_000) + expected_detection_blue_md: str | None = Field(default=None, max_length=32_000) + opsec_level: str | None = None + tags: list[str] | None = Field(default=None, max_length=64) + expected_iocs: list[str] | None = Field(default=None, max_length=128) + mitre_tags: list[MitreTagIn] | None = Field(default=None, max_length=64) + + model_config = {"extra": "forbid"} + + +# === Serializers ============================================================== + + +def _serialize(t: svc.TestTemplateView) -> dict[str, Any]: + return { + "id": str(t.id), + "name": t.name, + "description": t.description, + "objective": t.objective, + "procedure_md": t.procedure_md, + "prerequisites_md": t.prerequisites_md, + "expected_result_red_md": t.expected_result_red_md, + "expected_detection_blue_md": t.expected_detection_blue_md, + "opsec_level": t.opsec_level, + "tags": list(t.tags), + "expected_iocs": list(t.expected_iocs), + "mitre_tags": [ + {"kind": tag.kind, "external_id": tag.external_id, "name": tag.name, "url": tag.url} + for tag in t.mitre_tags + ], + "deleted_at": t.deleted_at.isoformat() if t.deleted_at else None, + "created_at": t.created_at.isoformat(), + "updated_at": t.updated_at.isoformat(), + } + + +def _parse_uuid_or_400(raw: str): + try: + return uuid.UUID(raw) + except ValueError: + return None + + +def _pagination_args() -> tuple[int, int] | tuple[None, tuple[int, str]]: + try: + limit = int(request.args.get("limit", "100")) + offset = int(request.args.get("offset", "0")) + except ValueError: + return None, (400, "invalid_pagination") + return max(1, min(limit, 500)), max(0, offset) + + +# === Endpoints ================================================================ + + +@bp.get("") +@require_auth +@require_perm("test_template.read") +def list_test_templates(): + paging = _pagination_args() + if paging[0] is None: + return jsonify({"error": paging[1][1]}), paging[1][0] + limit, offset = paging + q = request.args.get("q") or None + tactic = request.args.get("tactic") or None + technique = request.args.get("technique") or None + subtechnique = request.args.get("subtechnique") or None + opsec_level = request.args.get("opsec") or None + tag = request.args.get("tag") or None + include_deleted = request.args.get("include_deleted", "false").lower() == "true" + try: + items, total = svc.list_test_templates( + q=q, + tactic=tactic, + technique=technique, + subtechnique=subtechnique, + opsec_level=opsec_level, + tag=tag, + include_deleted=include_deleted, + limit=limit, + offset=offset, + ) + except ValueError as e: + return jsonify({"error": "invalid_request", "message": str(e)}), 400 + return jsonify( + { + "items": [_serialize(it) for it in items], + "total": total, + "limit": limit, + "offset": offset, + } + ) + + +@bp.get("/") +@require_auth +@require_perm("test_template.read") +def get_test_template(template_id: str): + tid = _parse_uuid_or_400(template_id) + if tid is None: + return jsonify({"error": "invalid_id"}), 400 + include_deleted = request.args.get("include_deleted", "false").lower() == "true" + try: + view = svc.get_test_template(tid, include_deleted=include_deleted) + except svc.TestTemplateNotFound: + return jsonify({"error": "not_found"}), 404 + return jsonify(_serialize(view)) + + +@bp.post("") +@require_auth +@require_perm("test_template.create") +def create_test_template(): + try: + payload = CreateTestTemplatePayload.model_validate(request.get_json(silent=True) or {}) + except ValidationError as e: + return jsonify({"error": "invalid_request", "details": e.errors()}), 400 + try: + view = svc.create_test_template( + name=payload.name, + description=payload.description, + objective=payload.objective, + procedure_md=payload.procedure_md, + prerequisites_md=payload.prerequisites_md, + expected_result_red_md=payload.expected_result_red_md, + expected_detection_blue_md=payload.expected_detection_blue_md, + opsec_level=payload.opsec_level, + tags=payload.tags, + expected_iocs=payload.expected_iocs, + mitre_tags=[svc.MitreTagRef(kind=t.kind, external_id=t.external_id) for t in payload.mitre_tags], + ) + except svc.UnknownMitreTag as e: + return jsonify({"error": "unknown_mitre_tag", "message": str(e)}), 400 + except ValueError as e: + return jsonify({"error": "invalid_request", "message": str(e)}), 400 + log.info( + "metamorph.test_template.created", + extra={"id": str(view.id), "template_name": view.name}, + ) + return jsonify(_serialize(view)), 201 + + +@bp.put("/") +@require_auth +@require_perm("test_template.update") +def update_test_template(template_id: str): + tid = _parse_uuid_or_400(template_id) + if tid is None: + return jsonify({"error": "invalid_id"}), 400 + raw = request.get_json(silent=True) or {} + try: + payload = UpdateTestTemplatePayload.model_validate(raw) + except ValidationError as e: + return jsonify({"error": "invalid_request", "details": e.errors()}), 400 + + # Only forward keys actually present in the body — model_validate leaves + # missing fields as None and we can't distinguish "explicitly null" from + # "omitted". The set of keys in `raw` is the wire-level intent. + kwargs: dict[str, Any] = {} + for field_name in ( + "name", "description", "objective", "procedure_md", "prerequisites_md", + "expected_result_red_md", "expected_detection_blue_md", + "opsec_level", "tags", "expected_iocs", + ): + if field_name in raw: + kwargs[field_name] = getattr(payload, field_name) + if "mitre_tags" in raw: + kwargs["mitre_tags"] = ( + [svc.MitreTagRef(kind=t.kind, external_id=t.external_id) for t in (payload.mitre_tags or [])] + ) + try: + view = svc.update_test_template(tid, **kwargs) + except svc.TestTemplateNotFound: + return jsonify({"error": "not_found"}), 404 + except svc.UnknownMitreTag as e: + return jsonify({"error": "unknown_mitre_tag", "message": str(e)}), 400 + except ValueError as e: + return jsonify({"error": "invalid_request", "message": str(e)}), 400 + log.info("metamorph.test_template.updated", extra={"id": str(tid), "fields": sorted(kwargs.keys())}) + return jsonify(_serialize(view)) + + +@bp.delete("/") +@require_auth +@require_perm("test_template.delete") +def soft_delete_test_template(template_id: str): + tid = _parse_uuid_or_400(template_id) + if tid is None: + return jsonify({"error": "invalid_id"}), 400 + try: + svc.soft_delete_test_template(tid) + except svc.TestTemplateNotFound: + return jsonify({"error": "not_found"}), 404 + log.info("metamorph.test_template.soft_deleted", extra={"id": str(tid)}) + return jsonify({"ok": True}) diff --git a/backend/app/api/v1.py b/backend/app/api/v1.py index 3423e2d..6b5379b 100644 --- a/backend/app/api/v1.py +++ b/backend/app/api/v1.py @@ -11,7 +11,9 @@ from app.api.health import bp as health_bp from app.api.invitations import bp as invitations_bp from app.api.mitre import bp as mitre_bp from app.api.permissions import bp as permissions_bp +from app.api.scenario_templates import bp as scenario_templates_bp from app.api.setup import bp as setup_bp +from app.api.test_templates import bp as test_templates_bp from app.api.users import bp as users_bp bp = Blueprint("v1", __name__, url_prefix="/api/v1") @@ -24,3 +26,5 @@ bp.register_blueprint(users_bp) bp.register_blueprint(groups_bp) bp.register_blueprint(permissions_bp) bp.register_blueprint(mitre_bp) +bp.register_blueprint(test_templates_bp) +bp.register_blueprint(scenario_templates_bp) diff --git a/backend/app/services/scenario_templates.py b/backend/app/services/scenario_templates.py new file mode 100644 index 0000000..968f7ad --- /dev/null +++ b/backend/app/services/scenario_templates.py @@ -0,0 +1,240 @@ +"""CRUD service for `scenario_templates` + their ordered test list. + +Re-ordering is implemented as **full delete + re-insert** of the +`scenario_template_tests` rows. The UNIQUE (scenario_template_id, position) +constraint makes any naive position-swap fail mid-transaction; wiping the set +then re-inserting at positions 0..N-1 keeps the operation atomic and obvious. + +The same test_template may legitimately appear multiple times in a scenario +(chained operations), so we key on `(scenario_id, position)`, not +`(scenario_id, test_template_id)`. +""" + +from __future__ import annotations + +import uuid +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Any + +from sqlalchemy import func, or_, select +from sqlalchemy.orm import Session, selectinload + +_UNSET: Any = object() + +from app.db.session import session_scope +from app.models.template import ( + ScenarioTemplate, + ScenarioTemplateTest, + TestTemplate, +) + + +class ScenarioTemplateNotFound(Exception): + pass + + +class UnknownTestTemplate(Exception): + """Raised when a scenario references a non-existent or soft-deleted test.""" + + +@dataclass(frozen=True) +class ScenarioTestView: + position: int + test_template_id: uuid.UUID + test_template_name: str + test_template_deleted: bool + + +@dataclass(frozen=True) +class ScenarioTemplateView: + id: uuid.UUID + name: str + description: str | None + tests: list[ScenarioTestView] + tests_count: int + deleted_at: datetime | None + created_at: datetime + updated_at: datetime + + +def _to_view(s: Session, sc: ScenarioTemplate) -> ScenarioTemplateView: + test_ids = [link.test_template_id for link in sc.tests] + name_by_id: dict[uuid.UUID, tuple[str, bool]] = {} + if test_ids: + rows = s.scalars(select(TestTemplate).where(TestTemplate.id.in_(test_ids))).all() + for row in rows: + name_by_id[row.id] = (row.name, row.deleted_at is not None) + tests = [ + ScenarioTestView( + position=link.position, + test_template_id=link.test_template_id, + test_template_name=name_by_id.get(link.test_template_id, ("", True))[0], + test_template_deleted=name_by_id.get(link.test_template_id, ("", True))[1], + ) + for link in sc.tests + ] + return ScenarioTemplateView( + id=sc.id, + name=sc.name, + description=sc.description, + tests=tests, + tests_count=len(tests), + deleted_at=sc.deleted_at, + created_at=sc.created_at, + updated_at=sc.updated_at, + ) + + +def _base_query(): + return select(ScenarioTemplate).options(selectinload(ScenarioTemplate.tests)) + + +def list_scenario_templates( + *, + q: str | None = None, + include_deleted: bool = False, + limit: int = 100, + offset: int = 0, +) -> tuple[list[ScenarioTemplateView], int]: + with session_scope() as s: + stmt = _base_query().order_by(ScenarioTemplate.name.asc()) + count_stmt = select(func.count()).select_from(ScenarioTemplate) + if not include_deleted: + stmt = stmt.where(ScenarioTemplate.deleted_at.is_(None)) + count_stmt = count_stmt.where(ScenarioTemplate.deleted_at.is_(None)) + if q: + like = f"%{q.lower()}%" + cond = or_( + func.lower(ScenarioTemplate.name).like(like), + func.lower(ScenarioTemplate.description).like(like), + ) + stmt = stmt.where(cond) + count_stmt = count_stmt.where(cond) + total = s.scalar(count_stmt) or 0 + rows = s.scalars(stmt.limit(max(1, min(limit, 500))).offset(max(0, offset))).all() + return [_to_view(s, sc) for sc in rows], int(total) + + +def get_scenario_template(scenario_id: uuid.UUID, *, include_deleted: bool = False) -> ScenarioTemplateView: + with session_scope() as s: + sc = s.get(ScenarioTemplate, scenario_id) + if sc is None: + raise ScenarioTemplateNotFound() + if sc.deleted_at is not None and not include_deleted: + raise ScenarioTemplateNotFound() + return _to_view(s, sc) + + +def _validate_test_ids(s: Session, ids: list[uuid.UUID]) -> None: + """Reject unknown or soft-deleted test_template ids before persisting.""" + if not ids: + return + found = s.execute( + select(TestTemplate.id, TestTemplate.deleted_at).where(TestTemplate.id.in_(ids)) + ).all() + known = {row.id for row in found} + deleted = {row.id for row in found if row.deleted_at is not None} + missing = set(ids) - known + if missing: + raise UnknownTestTemplate(f"unknown test_template ids: {sorted(str(m) for m in missing)}") + if deleted: + raise UnknownTestTemplate( + f"cannot reference soft-deleted test_template ids: {sorted(str(d) for d in deleted)}" + ) + + +def _opt_str(value: str | None) -> str | None: + if value is None: + return None + s = value.strip() + return s or None + + +def create_scenario_template( + *, + name: str, + description: str | None = None, + test_template_ids: list[uuid.UUID] | None = None, +) -> ScenarioTemplateView: + name_norm = (name or "").strip() + if not name_norm: + raise ValueError("name is required") + ids = list(test_template_ids or []) + with session_scope() as s: + _validate_test_ids(s, ids) + sc = ScenarioTemplate( + name=name_norm, + description=_opt_str(description), + ) + s.add(sc) + s.flush() + for position, tid in enumerate(ids): + s.add( + ScenarioTemplateTest( + scenario_template_id=sc.id, + test_template_id=tid, + position=position, + ) + ) + s.flush() + s.refresh(sc) + return _to_view(s, sc) + + +def update_scenario_template( + scenario_id: uuid.UUID, + *, + name: str | None = None, + description: Any = _UNSET, +) -> ScenarioTemplateView: + with session_scope() as s: + sc = s.get(ScenarioTemplate, scenario_id) + if sc is None or sc.deleted_at is not None: + raise ScenarioTemplateNotFound() + if name is not None: + n = name.strip() + if not n: + raise ValueError("name cannot be empty") + sc.name = n + if description is not _UNSET: + sc.description = _opt_str(description) + s.flush() + s.refresh(sc) + return _to_view(s, sc) + + +def set_scenario_tests( + scenario_id: uuid.UUID, + test_template_ids: list[uuid.UUID], +) -> ScenarioTemplateView: + """Replace the entire ordered test list. `position` becomes the index.""" + with session_scope() as s: + sc = s.get(ScenarioTemplate, scenario_id) + if sc is None or sc.deleted_at is not None: + raise ScenarioTemplateNotFound() + _validate_test_ids(s, test_template_ids) + # Wipe then re-insert. The UNIQUE(position) constraint forbids a + # naive UPDATE-swap; full-replace keeps the op atomic + readable. + for link in list(sc.tests): + s.delete(link) + s.flush() + for position, tid in enumerate(test_template_ids): + s.add( + ScenarioTemplateTest( + scenario_template_id=sc.id, + test_template_id=tid, + position=position, + ) + ) + s.flush() + s.refresh(sc) + return _to_view(s, sc) + + +def soft_delete_scenario_template(scenario_id: uuid.UUID) -> None: + with session_scope() as s: + sc = s.get(ScenarioTemplate, scenario_id) + if sc is None or sc.deleted_at is not None: + raise ScenarioTemplateNotFound() + sc.deleted_at = datetime.now(tz=timezone.utc) diff --git a/backend/app/services/test_templates.py b/backend/app/services/test_templates.py new file mode 100644 index 0000000..8570894 --- /dev/null +++ b/backend/app/services/test_templates.py @@ -0,0 +1,395 @@ +"""CRUD service for `test_templates` + their MITRE tags. + +The MITRE tag set is **fully replaced** on every update — partial mutation of +the join rows would force the API client to track tag UUIDs they never created. +The polymorphic join (one of `tactic_id` / `technique_id` / `subtechnique_id` +populated) is owned here: callers pass `(kind, external_id)` tuples and we +resolve them to the matching MITRE row. +""" + +from __future__ import annotations + +import uuid +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Any, Iterable + +from sqlalchemy import func, or_, select +from sqlalchemy.orm import Session, selectinload + +_UNSET: Any = object() + +from app.db.session import session_scope +from app.db.types import MITRE_KINDS, OPSEC_LEVELS +from app.models.mitre import MitreSubtechnique, MitreTactic, MitreTechnique +from app.models.template import TestTemplate, TestTemplateMitreTag + + +class TestTemplateNotFound(Exception): + pass + + +class UnknownMitreTag(Exception): + """Raised when an (kind, external_id) tuple doesn't resolve to a known MITRE row.""" + + +@dataclass(frozen=True) +class MitreTagRef: + """Inbound MITRE tag reference. `external_id` is the ATT&CK identifier + (TA…/T…/T….…) — we resolve it server-side, the client never sees UUIDs. + """ + + kind: str # "tactic" | "technique" | "subtechnique" + external_id: str + + +@dataclass(frozen=True) +class MitreTagView: + kind: str + external_id: str + name: str + url: str | None + + +@dataclass(frozen=True) +class TestTemplateView: + id: uuid.UUID + name: str + description: str | None + objective: str | None + procedure_md: str | None + prerequisites_md: str | None + expected_result_red_md: str | None + expected_detection_blue_md: str | None + opsec_level: str + tags: list[str] + expected_iocs: list[str] + mitre_tags: list[MitreTagView] + deleted_at: datetime | None + created_at: datetime + updated_at: datetime + + +def _validate_opsec(value: str) -> str: + if value not in OPSEC_LEVELS: + raise ValueError(f"opsec_level must be one of {OPSEC_LEVELS}") + return value + + +def _normalize_string_list(values: Iterable[str] | None) -> list[str]: + if not values: + return [] + seen: set[str] = set() + out: list[str] = [] + for raw in values: + if not isinstance(raw, str): + raise ValueError("list items must be strings") + v = raw.strip() + if not v or v in seen: + continue + seen.add(v) + out.append(v) + return out + + +def _resolve_mitre_refs(s: Session, refs: list[MitreTagRef]) -> list[TestTemplateMitreTag]: + """Translate `(kind, external_id)` pairs into half-populated join rows. + + Validates that: + - `kind` is one of the supported values + - each external_id resolves to an existing MITRE row + - the combination is unique inside the payload (de-duped silently — same + tag twice is a no-op, not an error) + """ + if not refs: + return [] + # Dedupe input + deduped: dict[tuple[str, str], MitreTagRef] = {} + for ref in refs: + if ref.kind not in MITRE_KINDS: + raise ValueError(f"mitre tag kind must be one of {MITRE_KINDS}") + if not ref.external_id: + raise ValueError("mitre tag external_id is required") + deduped[(ref.kind, ref.external_id)] = ref + + tactic_ids = {r.external_id for r in deduped.values() if r.kind == "tactic"} + technique_ids = {r.external_id for r in deduped.values() if r.kind == "technique"} + subtechnique_ids = {r.external_id for r in deduped.values() if r.kind == "subtechnique"} + + tactic_map = { + t.external_id: t.id + for t in s.scalars(select(MitreTactic).where(MitreTactic.external_id.in_(tactic_ids))).all() + } + technique_map = { + t.external_id: t.id + for t in s.scalars(select(MitreTechnique).where(MitreTechnique.external_id.in_(technique_ids))).all() + } + subtechnique_map = { + sb.external_id: sb.id + for sb in s.scalars( + select(MitreSubtechnique).where(MitreSubtechnique.external_id.in_(subtechnique_ids)) + ).all() + } + + rows: list[TestTemplateMitreTag] = [] + missing: list[tuple[str, str]] = [] + for ref in deduped.values(): + if ref.kind == "tactic": + mid = tactic_map.get(ref.external_id) + if mid is None: + missing.append((ref.kind, ref.external_id)) + continue + rows.append(TestTemplateMitreTag(mitre_kind="tactic", tactic_id=mid)) + elif ref.kind == "technique": + mid = technique_map.get(ref.external_id) + if mid is None: + missing.append((ref.kind, ref.external_id)) + continue + rows.append(TestTemplateMitreTag(mitre_kind="technique", technique_id=mid)) + else: + mid = subtechnique_map.get(ref.external_id) + if mid is None: + missing.append((ref.kind, ref.external_id)) + continue + rows.append(TestTemplateMitreTag(mitre_kind="subtechnique", subtechnique_id=mid)) + if missing: + raise UnknownMitreTag(f"unknown MITRE tags: {sorted(missing)}") + return rows + + +def _to_view(s: Session, t: TestTemplate) -> TestTemplateView: + tag_views: list[MitreTagView] = [] + for tag in t.mitre_tags: + if tag.mitre_kind == "tactic" and tag.tactic_id is not None: + row = s.get(MitreTactic, tag.tactic_id) + if row is not None: + tag_views.append(MitreTagView(kind="tactic", external_id=row.external_id, name=row.name, url=row.url)) + elif tag.mitre_kind == "technique" and tag.technique_id is not None: + row = s.get(MitreTechnique, tag.technique_id) + if row is not None: + tag_views.append(MitreTagView(kind="technique", external_id=row.external_id, name=row.name, url=row.url)) + elif tag.mitre_kind == "subtechnique" and tag.subtechnique_id is not None: + row = s.get(MitreSubtechnique, tag.subtechnique_id) + if row is not None: + tag_views.append( + MitreTagView(kind="subtechnique", external_id=row.external_id, name=row.name, url=row.url) + ) + tag_views.sort(key=lambda v: (v.kind, v.external_id)) + return TestTemplateView( + id=t.id, + name=t.name, + description=t.description, + objective=t.objective, + procedure_md=t.procedure_md, + prerequisites_md=t.prerequisites_md, + expected_result_red_md=t.expected_result_red_md, + expected_detection_blue_md=t.expected_detection_blue_md, + opsec_level=t.opsec_level, + tags=list(t.tags or []), + expected_iocs=list(t.expected_iocs or []), + mitre_tags=tag_views, + deleted_at=t.deleted_at, + created_at=t.created_at, + updated_at=t.updated_at, + ) + + +def _base_query(): + return select(TestTemplate).options(selectinload(TestTemplate.mitre_tags)) + + +def list_test_templates( + *, + q: str | None = None, + tactic: str | None = None, # external_id like "TA0006" + technique: str | None = None, + subtechnique: str | None = None, + opsec_level: str | None = None, + tag: str | None = None, + include_deleted: bool = False, + limit: int = 100, + offset: int = 0, +) -> tuple[list[TestTemplateView], int]: + with session_scope() as s: + stmt = _base_query().order_by(TestTemplate.name.asc()) + count_stmt = select(func.count()).select_from(TestTemplate) + if not include_deleted: + stmt = stmt.where(TestTemplate.deleted_at.is_(None)) + count_stmt = count_stmt.where(TestTemplate.deleted_at.is_(None)) + if q: + like = f"%{q.lower()}%" + cond = or_( + func.lower(TestTemplate.name).like(like), + func.lower(TestTemplate.description).like(like), + ) + stmt = stmt.where(cond) + count_stmt = count_stmt.where(cond) + if opsec_level: + _validate_opsec(opsec_level) + stmt = stmt.where(TestTemplate.opsec_level == opsec_level) + count_stmt = count_stmt.where(TestTemplate.opsec_level == opsec_level) + if tag: + stmt = stmt.where(TestTemplate.tags.any(tag)) + count_stmt = count_stmt.where(TestTemplate.tags.any(tag)) + + # MITRE facet: resolve external_id → uuid then filter via join subquery. + if tactic or technique or subtechnique: + tag_ids: list[uuid.UUID] = [] + if tactic: + tac = s.scalar(select(MitreTactic).where(MitreTactic.external_id == tactic)) + if tac is None: + return [], 0 + tag_ids.append(tac.id) + if technique: + tech = s.scalar(select(MitreTechnique).where(MitreTechnique.external_id == technique)) + if tech is None: + return [], 0 + tag_ids.append(tech.id) + if subtechnique: + sub = s.scalar(select(MitreSubtechnique).where(MitreSubtechnique.external_id == subtechnique)) + if sub is None: + return [], 0 + tag_ids.append(sub.id) + sub_q = ( + select(TestTemplateMitreTag.test_template_id) + .where( + or_( + TestTemplateMitreTag.tactic_id.in_(tag_ids), + TestTemplateMitreTag.technique_id.in_(tag_ids), + TestTemplateMitreTag.subtechnique_id.in_(tag_ids), + ) + ) + .distinct() + ) + stmt = stmt.where(TestTemplate.id.in_(sub_q)) + count_stmt = count_stmt.where(TestTemplate.id.in_(sub_q)) + + total = s.scalar(count_stmt) or 0 + rows = s.scalars(stmt.limit(max(1, min(limit, 500))).offset(max(0, offset))).all() + return [_to_view(s, t) for t in rows], int(total) + + +def get_test_template(template_id: uuid.UUID, *, include_deleted: bool = False) -> TestTemplateView: + with session_scope() as s: + t = s.get(TestTemplate, template_id) + if t is None: + raise TestTemplateNotFound() + if t.deleted_at is not None and not include_deleted: + raise TestTemplateNotFound() + return _to_view(s, t) + + +def create_test_template( + *, + name: str, + description: str | None = None, + objective: str | None = None, + procedure_md: str | None = None, + prerequisites_md: str | None = None, + expected_result_red_md: str | None = None, + expected_detection_blue_md: str | None = None, + opsec_level: str = "medium", + tags: list[str] | None = None, + expected_iocs: list[str] | None = None, + mitre_tags: list[MitreTagRef] | None = None, +) -> TestTemplateView: + name_norm = (name or "").strip() + if not name_norm: + raise ValueError("name is required") + _validate_opsec(opsec_level) + norm_tags = _normalize_string_list(tags) + norm_iocs = _normalize_string_list(expected_iocs) + with session_scope() as s: + t = TestTemplate( + name=name_norm, + description=_opt_str(description), + objective=_opt_str(objective), + procedure_md=procedure_md or None, + prerequisites_md=prerequisites_md or None, + expected_result_red_md=expected_result_red_md or None, + expected_detection_blue_md=expected_detection_blue_md or None, + opsec_level=opsec_level, + tags=norm_tags, + expected_iocs=norm_iocs, + ) + s.add(t) + s.flush() + if mitre_tags: + rows = _resolve_mitre_refs(s, mitre_tags) + for row in rows: + row.test_template_id = t.id + s.add(row) + s.flush() + s.refresh(t) + return _to_view(s, t) + + +def _opt_str(value: str | None) -> str | None: + if value is None: + return None + s = value.strip() + return s or None + + +def update_test_template( + template_id: uuid.UUID, + *, + name: str | None = None, + description: Any = _UNSET, + objective: Any = _UNSET, + procedure_md: Any = _UNSET, + prerequisites_md: Any = _UNSET, + expected_result_red_md: Any = _UNSET, + expected_detection_blue_md: Any = _UNSET, + opsec_level: str | None = None, + tags: Any = _UNSET, + expected_iocs: Any = _UNSET, + mitre_tags: Any = _UNSET, +) -> TestTemplateView: + with session_scope() as s: + t = s.get(TestTemplate, template_id) + if t is None or t.deleted_at is not None: + raise TestTemplateNotFound() + if name is not None: + n = name.strip() + if not n: + raise ValueError("name cannot be empty") + t.name = n + if description is not _UNSET: + t.description = _opt_str(description) + if objective is not _UNSET: + t.objective = _opt_str(objective) + if procedure_md is not _UNSET: + t.procedure_md = procedure_md or None + if prerequisites_md is not _UNSET: + t.prerequisites_md = prerequisites_md or None + if expected_result_red_md is not _UNSET: + t.expected_result_red_md = expected_result_red_md or None + if expected_detection_blue_md is not _UNSET: + t.expected_detection_blue_md = expected_detection_blue_md or None + if opsec_level is not None: + _validate_opsec(opsec_level) + t.opsec_level = opsec_level + if tags is not _UNSET: + t.tags = _normalize_string_list(tags) + if expected_iocs is not _UNSET: + t.expected_iocs = _normalize_string_list(expected_iocs) + if mitre_tags is not _UNSET: + for row in list(t.mitre_tags): + s.delete(row) + s.flush() + rows = _resolve_mitre_refs(s, list(mitre_tags or [])) + for row in rows: + row.test_template_id = t.id + s.add(row) + s.flush() + s.refresh(t) + return _to_view(s, t) + + +def soft_delete_test_template(template_id: uuid.UUID) -> None: + with session_scope() as s: + t = s.get(TestTemplate, template_id) + if t is None or t.deleted_at is not None: + raise TestTemplateNotFound() + t.deleted_at = datetime.now(tz=timezone.utc) diff --git a/backend/tests/test_templates.py b/backend/tests/test_templates.py new file mode 100644 index 0000000..8b4bc4b --- /dev/null +++ b/backend/tests/test_templates.py @@ -0,0 +1,492 @@ +"""M5 — Template catalogue integration tests. + +Covers `test_template` and `scenario_template` CRUD + ordering + perm gating. +Relies on a minimal MITRE seed (T1059 / TA0001 / T1059.001) so the polymorphic +tag join can be exercised end-to-end. +""" + +from __future__ import annotations + +import json +import secrets + +import pytest +from sqlalchemy import text + +from app.core.install_token import regenerate_install_token +from app.main import create_app +from app.services import mitre_seed as mitre_svc + + +def _truncate_all(engine): + with engine.begin() as conn: + conn.execute( + text( + "TRUNCATE users, refresh_tokens, invitations, invitation_groups, " + "user_groups, group_permissions, permissions, settings, groups, " + "scenario_template_tests, scenario_templates, " + "test_template_mitre_tags, test_templates, " + "mitre_subtechniques, mitre_technique_tactics, mitre_techniques, " + "mitre_tactics RESTART IDENTITY CASCADE" + ) + ) + + +# Same minimal bundle as in test_mitre.py — keeps tag resolution deterministic +# without re-pulling the full enterprise STIX bundle. +_MINIMAL_BUNDLE = { + "type": "bundle", + "id": "bundle--00000000-0000-0000-0000-000000000002", + "spec_version": "2.1", + "objects": [ + { + "type": "x-mitre-tactic", + "id": "x-mitre-tactic--ta0001", + "name": "Initial Access", + "x_mitre_shortname": "initial-access", + "external_references": [ + {"source_name": "mitre-attack", "external_id": "TA0001"} + ], + }, + { + "type": "x-mitre-tactic", + "id": "x-mitre-tactic--ta0002", + "name": "Execution", + "x_mitre_shortname": "execution", + "external_references": [ + {"source_name": "mitre-attack", "external_id": "TA0002"} + ], + }, + { + "type": "attack-pattern", + "id": "attack-pattern--t1059", + "name": "Command and Scripting Interpreter", + "kill_chain_phases": [ + {"kill_chain_name": "mitre-attack", "phase_name": "execution"} + ], + "external_references": [ + {"source_name": "mitre-attack", "external_id": "T1059"} + ], + }, + { + "type": "attack-pattern", + "id": "attack-pattern--t1059-001", + "name": "PowerShell", + "x_mitre_is_subtechnique": True, + "external_references": [ + {"source_name": "mitre-attack", "external_id": "T1059.001"} + ], + }, + { + "type": "relationship", + "id": "relationship--rel1", + "relationship_type": "subtechnique-of", + "source_ref": "attack-pattern--t1059-001", + "target_ref": "attack-pattern--t1059", + }, + ], +} + + +@pytest.fixture(scope="module") +def app(db_engine_or_skip, tmp_path_factory): + _truncate_all(db_engine_or_skip) + bundle_path = tmp_path_factory.mktemp("m5") / "stix.json" + bundle_path.write_text(json.dumps(_MINIMAL_BUNDLE)) + mitre_svc.seed_mitre(source=bundle_path, expected_sha256=None) + flask_app = create_app() + flask_app.config.update(TESTING=True) + return flask_app + + +@pytest.fixture() +def client(app): + return app.test_client() + + +def _unique_email(prefix: str) -> str: + return f"{prefix}-{secrets.token_hex(4)}@metamorph.local" + + +@pytest.fixture(scope="module") +def admin(app): + token = regenerate_install_token() + email = _unique_email("admin") + password = "AdminPass1234!" + with app.test_client() as c: + r = c.post( + "/api/v1/setup", + json={"install_token": token, "email": email, "password": password}, + ) + assert r.status_code == 201, r.get_data(as_text=True) + return {"email": email, "password": password} + + +def _login(client, email: str, password: str) -> str: + r = client.post("/api/v1/auth/login", json={"email": email, "password": password}) + assert r.status_code == 200, r.get_data(as_text=True) + return r.get_json()["access_token"] + + +def _bearer(token: str) -> dict[str, str]: + return {"Authorization": f"Bearer {token}"} + + +@pytest.fixture() +def admin_token(client, admin) -> str: + return _login(client, admin["email"], admin["password"]) + + +# === Reader fixture: an invited user with only `test_template.read` ========= + + +def _bootstrap_user_without_perms(client, admin_token: str, prefix: str) -> tuple[str, str]: + email = _unique_email(prefix) + inv = client.post( + "/api/v1/invitations", + headers=_bearer(admin_token), + json={"email_hint": email}, + ) + token = inv.get_json()["token"] + password = "ReaderPass1234!" + client.post( + f"/api/v1/invitations/accept/{token}", + json={"email": email, "password": password}, + ) + return email, _login(client, email, password) + + +# === test_template CRUD ===================================================== + + +def _make_test(client, admin_token: str, **overrides): + body = { + "name": overrides.pop("name", f"Test {secrets.token_hex(2)}"), + "description": overrides.pop("description", "auto"), + "objective": "do thing", + "procedure_md": "1. step", + "expected_result_red_md": "red expectation", + "expected_detection_blue_md": "blue expectation", + "opsec_level": overrides.pop("opsec_level", "medium"), + "tags": overrides.pop("tags", ["fast"]), + "expected_iocs": ["evil.exe"], + "mitre_tags": overrides.pop("mitre_tags", [{"kind": "technique", "external_id": "T1059"}]), + **overrides, + } + r = client.post("/api/v1/test-templates", headers=_bearer(admin_token), json=body) + assert r.status_code == 201, r.get_data(as_text=True) + return r.get_json() + + +def test_create_test_template_with_mitre_tags(client, admin_token): + body = _make_test( + client, + admin_token, + name="PowerShell exec", + mitre_tags=[ + {"kind": "tactic", "external_id": "TA0002"}, + {"kind": "technique", "external_id": "T1059"}, + {"kind": "subtechnique", "external_id": "T1059.001"}, + ], + ) + assert body["opsec_level"] == "medium" + kinds = sorted((t["kind"], t["external_id"]) for t in body["mitre_tags"]) + assert kinds == [ + ("subtechnique", "T1059.001"), + ("tactic", "TA0002"), + ("technique", "T1059"), + ] + + +def test_create_test_template_rejects_unknown_mitre(client, admin_token): + r = client.post( + "/api/v1/test-templates", + headers=_bearer(admin_token), + json={ + "name": "Bad", + "mitre_tags": [{"kind": "technique", "external_id": "T9999"}], + }, + ) + assert r.status_code == 400 + assert r.get_json()["error"] == "unknown_mitre_tag" + + +def test_create_test_template_rejects_bad_opsec(client, admin_token): + r = client.post( + "/api/v1/test-templates", + headers=_bearer(admin_token), + json={"name": "Bad", "opsec_level": "burner"}, + ) + assert r.status_code == 400 + + +def test_list_test_templates_filter_by_tactic(client, admin_token): + _make_test( + client, + admin_token, + name="filterable-1", + mitre_tags=[{"kind": "tactic", "external_id": "TA0002"}], + ) + r = client.get( + "/api/v1/test-templates?tactic=TA0002", + headers=_bearer(admin_token), + ) + assert r.status_code == 200 + body = r.get_json() + names = [it["name"] for it in body["items"]] + assert "filterable-1" in names + + +def test_list_test_templates_filter_by_opsec(client, admin_token): + _make_test(client, admin_token, name="high-opsec", opsec_level="high") + r = client.get( + "/api/v1/test-templates?opsec=high", + headers=_bearer(admin_token), + ) + assert r.status_code == 200 + names = [it["name"] for it in r.get_json()["items"]] + assert "high-opsec" in names + assert all(it["opsec_level"] == "high" for it in r.get_json()["items"]) + + +def test_list_test_templates_filter_by_tag(client, admin_token): + _make_test(client, admin_token, name="tagged-fast", tags=["fast", "phish"]) + r = client.get( + "/api/v1/test-templates?tag=phish", + headers=_bearer(admin_token), + ) + assert r.status_code == 200 + names = [it["name"] for it in r.get_json()["items"]] + assert "tagged-fast" in names + + +def test_list_test_templates_search_q(client, admin_token): + _make_test(client, admin_token, name="unique-token-azertyuiop") + r = client.get( + "/api/v1/test-templates?q=AZERTYUIOP", # case-insensitive + headers=_bearer(admin_token), + ) + assert r.status_code == 200 + names = [it["name"] for it in r.get_json()["items"]] + assert "unique-token-azertyuiop" in names + + +def test_update_test_template_replaces_mitre_tags(client, admin_token): + body = _make_test( + client, + admin_token, + name="to-update", + mitre_tags=[{"kind": "tactic", "external_id": "TA0001"}], + ) + r = client.put( + f"/api/v1/test-templates/{body['id']}", + headers=_bearer(admin_token), + json={"mitre_tags": [{"kind": "technique", "external_id": "T1059"}]}, + ) + assert r.status_code == 200, r.get_data(as_text=True) + updated = r.get_json() + kinds = [(t["kind"], t["external_id"]) for t in updated["mitre_tags"]] + assert kinds == [("technique", "T1059")] + + +def test_update_test_template_partial_keeps_unset_fields(client, admin_token): + body = _make_test( + client, + admin_token, + name="partial-update", + opsec_level="low", + tags=["a", "b"], + ) + r = client.put( + f"/api/v1/test-templates/{body['id']}", + headers=_bearer(admin_token), + json={"name": "renamed"}, + ) + assert r.status_code == 200 + updated = r.get_json() + assert updated["name"] == "renamed" + assert updated["opsec_level"] == "low" # untouched + assert set(updated["tags"]) == {"a", "b"} # untouched + + +def test_soft_delete_then_list_hides_by_default(client, admin_token): + body = _make_test(client, admin_token, name="to-be-deleted") + r = client.delete( + f"/api/v1/test-templates/{body['id']}", headers=_bearer(admin_token) + ) + assert r.status_code == 200 + r2 = client.get("/api/v1/test-templates", headers=_bearer(admin_token)) + names = [it["name"] for it in r2.get_json()["items"]] + assert "to-be-deleted" not in names + # And reappears with include_deleted=true + r3 = client.get( + "/api/v1/test-templates?include_deleted=true", + headers=_bearer(admin_token), + ) + names3 = [it["name"] for it in r3.get_json()["items"]] + assert "to-be-deleted" in names3 + + +def test_read_perm_required(client, admin_token): + """A user without `test_template.read` gets 403.""" + _, eve_token = _bootstrap_user_without_perms(client, admin_token, "eve-noperm") + r = client.get("/api/v1/test-templates", headers=_bearer(eve_token)) + assert r.status_code == 403 + + +def test_write_perm_required(client, admin_token): + """A user with only `test_template.read` cannot create. + + Bootstrap path: create a dedicated group via the admin API, bind only the + `test_template.read` perm, then invite a user pre-assigned to that group. + """ + # 1. Create the read-only group + bind the single perm. + grp = client.post( + "/api/v1/groups", + headers=_bearer(admin_token), + json={"name": f"tpl-reader-{secrets.token_hex(2)}"}, + ).get_json() + r_set = client.put( + f"/api/v1/groups/{grp['id']}/permissions", + headers=_bearer(admin_token), + json={"codes": ["test_template.read"]}, + ) + assert r_set.status_code == 200, r_set.get_data(as_text=True) + + # 2. Invite a user already attached to that group. + email = _unique_email("alice-readonly") + password = "ReaderPass1234!" + inv = client.post( + "/api/v1/invitations", + headers=_bearer(admin_token), + json={"email_hint": email, "group_ids": [grp["id"]]}, + ).get_json() + client.post( + f"/api/v1/invitations/accept/{inv['token']}", + json={"email": email, "password": password}, + ) + token = _login(client, email, password) + + r = client.get("/api/v1/test-templates", headers=_bearer(token)) + assert r.status_code == 200, r.get_data(as_text=True) + r2 = client.post( + "/api/v1/test-templates", headers=_bearer(token), json={"name": "X"} + ) + assert r2.status_code == 403 + + +# === scenario_template CRUD ================================================= + + +def test_create_scenario_with_ordered_tests(client, admin_token): + a = _make_test(client, admin_token, name="scn-a") + b = _make_test(client, admin_token, name="scn-b") + c = _make_test(client, admin_token, name="scn-c") + r = client.post( + "/api/v1/scenario-templates", + headers=_bearer(admin_token), + json={ + "name": "phishing-flow", + "description": "click → exec → persist", + "test_template_ids": [a["id"], b["id"], c["id"]], + }, + ) + assert r.status_code == 201, r.get_data(as_text=True) + body = r.get_json() + assert body["tests_count"] == 3 + assert [t["position"] for t in body["tests"]] == [0, 1, 2] + assert [t["test_template_name"] for t in body["tests"]] == ["scn-a", "scn-b", "scn-c"] + + +def test_reorder_scenario_tests(client, admin_token): + a = _make_test(client, admin_token, name="reord-a") + b = _make_test(client, admin_token, name="reord-b") + c = _make_test(client, admin_token, name="reord-c") + created = client.post( + "/api/v1/scenario-templates", + headers=_bearer(admin_token), + json={ + "name": "reorder-me", + "test_template_ids": [a["id"], b["id"], c["id"]], + }, + ).get_json() + # Reverse order. + r = client.put( + f"/api/v1/scenario-templates/{created['id']}/tests", + headers=_bearer(admin_token), + json={"test_template_ids": [c["id"], b["id"], a["id"]]}, + ) + assert r.status_code == 200 + after = r.get_json() + assert [t["test_template_name"] for t in after["tests"]] == ["reord-c", "reord-b", "reord-a"] + # Re-reading via GET yields the same order — confirms persistence. + fresh = client.get( + f"/api/v1/scenario-templates/{created['id']}", headers=_bearer(admin_token) + ).get_json() + assert [t["test_template_name"] for t in fresh["tests"]] == ["reord-c", "reord-b", "reord-a"] + + +def test_scenario_rejects_unknown_test_id(client, admin_token): + r = client.post( + "/api/v1/scenario-templates", + headers=_bearer(admin_token), + json={ + "name": "bad", + "test_template_ids": ["00000000-0000-0000-0000-000000000000"], + }, + ) + assert r.status_code == 400 + assert r.get_json()["error"] == "unknown_test_template" + + +def test_scenario_rejects_soft_deleted_test_on_create(client, admin_token): + a = _make_test(client, admin_token, name="will-be-deleted") + client.delete(f"/api/v1/test-templates/{a['id']}", headers=_bearer(admin_token)) + r = client.post( + "/api/v1/scenario-templates", + headers=_bearer(admin_token), + json={"name": "linked", "test_template_ids": [a["id"]]}, + ) + assert r.status_code == 400 + assert r.get_json()["error"] == "unknown_test_template" + + +def test_scenario_surfaces_soft_deleted_test_after_link(client, admin_token): + """Once linked, a test can be soft-deleted without breaking the scenario — + the join row stays and the API flags the test as deleted.""" + a = _make_test(client, admin_token, name="linked-then-deleted") + sc = client.post( + "/api/v1/scenario-templates", + headers=_bearer(admin_token), + json={"name": "survives", "test_template_ids": [a["id"]]}, + ).get_json() + client.delete(f"/api/v1/test-templates/{a['id']}", headers=_bearer(admin_token)) + fresh = client.get( + f"/api/v1/scenario-templates/{sc['id']}", headers=_bearer(admin_token) + ).get_json() + assert fresh["tests"][0]["test_template_deleted"] is True + + +def test_scenario_soft_delete(client, admin_token): + sc = client.post( + "/api/v1/scenario-templates", + headers=_bearer(admin_token), + json={"name": "doomed-scn"}, + ).get_json() + r = client.delete( + f"/api/v1/scenario-templates/{sc['id']}", headers=_bearer(admin_token) + ) + assert r.status_code == 200 + names = [ + it["name"] + for it in client.get( + "/api/v1/scenario-templates", headers=_bearer(admin_token) + ).get_json()["items"] + ] + assert "doomed-scn" not in names + + +def test_scenario_perm_required(client, admin_token): + _, eve_token = _bootstrap_user_without_perms(client, admin_token, "scn-eve") + r = client.get("/api/v1/scenario-templates", headers=_bearer(eve_token)) + assert r.status_code == 403 -- 2.49.1 From 2781ce411755a46dd3074a08b8f46e96a2eaa553 Mon Sep 17 00:00:00 2001 From: Knacky Date: Tue, 12 May 2026 19:57:41 +0200 Subject: [PATCH 2/6] feat(m5): admin SPA pages for the template catalogue - AdminTestsPage with filters (q, tactic, opsec, tag), modal-based CRUD, markdown textareas for procedure/result/detection, embedded MitreTagPicker for tagging. - AdminScenariosPage with @dnd-kit/sortable drag-and-drop on the ordered test list, two-step save (PATCH metadata + PUT tests), catalogue picker excluding soft-deleted items. - lib/templates.ts typed client + queryKey factory. - MarkdownField helper (textarea with markdown hint label). - Layout adds Tests + Scenarios admin nav links; App.tsx routes both behind RequireAdmin. Co-Authored-By: Claude Opus 4.7 (1M context) --- frontend/package.json | 3 + frontend/src/App.tsx | 18 + frontend/src/components/Layout.tsx | 4 +- frontend/src/components/MarkdownField.tsx | 45 +++ frontend/src/lib/templates.ts | 136 +++++++ frontend/src/pages/AdminScenariosPage.tsx | 434 ++++++++++++++++++++++ frontend/src/pages/AdminTestsPage.tsx | 403 ++++++++++++++++++++ 7 files changed, 1042 insertions(+), 1 deletion(-) create mode 100644 frontend/src/components/MarkdownField.tsx create mode 100644 frontend/src/lib/templates.ts create mode 100644 frontend/src/pages/AdminScenariosPage.tsx create mode 100644 frontend/src/pages/AdminTestsPage.tsx diff --git a/frontend/package.json b/frontend/package.json index 1f7e91f..2da9cbc 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -13,6 +13,9 @@ "format:check": "prettier --check \"src/**/*.{ts,tsx,css,json,html}\"" }, "dependencies": { + "@dnd-kit/core": "^6.1.0", + "@dnd-kit/sortable": "^8.0.0", + "@dnd-kit/utilities": "^3.2.2", "@fontsource/ibm-plex-sans": "^5.0.20", "@fontsource/jetbrains-mono": "^5.0.20", "@tanstack/react-query": "^5.51.0", diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index db48353..44f8731 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -6,6 +6,8 @@ import { RequireAdmin } from '@/components/RequireAdmin'; import { RequireAuth } from '@/components/RequireAuth'; import { AdminGroupsPage } from '@/pages/AdminGroupsPage'; import { AdminInvitationsPage } from '@/pages/AdminInvitationsPage'; +import { AdminScenariosPage } from '@/pages/AdminScenariosPage'; +import { AdminTestsPage } from '@/pages/AdminTestsPage'; import { AdminUsersPage } from '@/pages/AdminUsersPage'; import { HomePage } from '@/pages/HomePage'; import { MitrePage } from '@/pages/MitrePage'; @@ -82,6 +84,22 @@ function App() { } /> + + + + } + /> + + + + } + /> } /> diff --git a/frontend/src/components/Layout.tsx b/frontend/src/components/Layout.tsx index 0eab999..a9b269b 100644 --- a/frontend/src/components/Layout.tsx +++ b/frontend/src/components/Layout.tsx @@ -42,6 +42,8 @@ export function Layout() { {navItem('/admin/users', 'Users')} {navItem('/admin/groups', 'Groups')} {navItem('/admin/invitations', 'Invitations')} + {navItem('/admin/tests', 'Tests')} + {navItem('/admin/scenarios', 'Scenarios')} )} @@ -69,7 +71,7 @@ export function Layout() {
- metamorph · M0 bootstrap · M1 db schema · M2 auth · M3 rbac · M4 mitre · design system from tasks/design.md + metamorph · M0 bootstrap · M1 db schema · M2 auth · M3 rbac · M4 mitre · M5 templates · design system from tasks/design.md
diff --git a/frontend/src/components/MarkdownField.tsx b/frontend/src/components/MarkdownField.tsx new file mode 100644 index 0000000..635793a --- /dev/null +++ b/frontend/src/components/MarkdownField.tsx @@ -0,0 +1,45 @@ +import { useId, type TextareaHTMLAttributes } from 'react'; + +import { cn } from '@/lib/cn'; + +interface MarkdownFieldProps extends Omit, 'value' | 'onChange'> { + label: string; + value: string; + onChange: (next: string) => void; + rows?: number; + hint?: string; +} + +/** + * Markdown-content textarea. We deliberately keep it textarea-only (no fancy + * WYSIWYG editor) — markdown lives well in plain text and the saved blob is + * rendered to HTML at display time (M6/M7 mission pages). The label exposes + * "markdown" so the user knows the field accepts MD syntax. + */ +export function MarkdownField({ label, value, onChange, rows = 6, hint, id, className, ...rest }: MarkdownFieldProps) { + const fallbackId = useId(); + const inputId = id ?? fallbackId; + return ( +
+ +