diff --git a/CHANGELOG.md b/CHANGELOG.md index ea2f1a0..dd72e11 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,52 @@ All notable changes to this project will be documented here. Format: [Keep a Cha ## [Unreleased] +### Added — M5 (Test & scenario templates) +- **CRUD `test_templates`** (`app/services/test_templates.py` + `app/api/test_templates.py`): + - Fields: name, description, objective, procedure (markdown), prerequisites (markdown), expected result red, expected detection blue, OPSEC level (`low/medium/high`), free tags (TEXT[]), expected IOCs (TEXT[]). + - Polymorphic MITRE tag set (`(kind, external_id)` ↔ exactly one of `tactic_id`/`technique_id`/`subtechnique_id`). The wire payload uses ATT&CK external IDs — server resolves to UUIDs. + - Filters: `q` (LIKE on name/description), `tactic`/`technique`/`subtechnique` (joined via subquery on the polymorphic tag table), `opsec`, `tag` (array contains). + - REST: `GET /test-templates`, `GET /test-templates/{id}`, `POST /test-templates`, `PUT /test-templates/{id}` (partial, with explicit `_UNSET` sentinel so omitted fields stay untouched), `DELETE /test-templates/{id}` (soft). +- **CRUD `scenario_templates`** (`app/services/scenario_templates.py` + `app/api/scenario_templates.py`): + - Ordered list of test_templates with `position` (UNIQUE `scenario_template_id, position`). + - Reorder via full replace: `PUT /scenario-templates/{id}/tests` deletes the join rows and re-inserts at positions `0..N-1` — clean atomic op that respects the UNIQUE constraint without a 2-phase position shuffle. + - The same test can appear multiple times (chained operations). + - REST: `GET`/`POST`/`PATCH` (metadata) / `DELETE` (soft) on `/scenario-templates`. +- **Frontend**: + - `lib/templates.ts` — typed client + queryKey factory. + - `pages/AdminTestsPage.tsx` — list + filters (q, tactic, opsec, tag) + modal with full field set + embedded `` for tags. + - `pages/AdminScenariosPage.tsx` — list + modal with **@dnd-kit/sortable** vertical drag-and-drop on the ordered test list. New deps: `@dnd-kit/core`, `@dnd-kit/sortable`, `@dnd-kit/utilities`. + - `components/MarkdownField.tsx` — lean textarea with markdown hint (no heavy editor dep; rendering happens at display time in M7). + - Nav adds **Tests** and **Scenarios** links (admin-gated). +- **/diag/reset** truncates the 4 new tables before the MITRE block — the `scenario_template_tests.test_template_id` FK is `ON DELETE RESTRICT`, so the order matters. +- **Testing**: + - `backend/tests/test_templates.py` — **19 pytest** (create/list/filter by tactic+opsec+tag, MITRE tag resolution + replacement on update, soft-delete, perm gating, scenario create+reorder+delete, soft-deleted test linking semantics). + - `e2e/tests/m5-templates.spec.ts` — **4 Playwright** (API CRUD round-trip, scenario reorder, SPA list + opsec filter, SPA scenario list rendering with ordered tests). + - `tasks/testing-m5.md`. + +### Fixed (M5 implementation) +- **`LogRecord` key collision**: `log.info(..., extra={"name": ...})` raises `KeyError("Attempt to overwrite 'name' in LogRecord")` because `name` is reserved by Python's stdlib logging. Renamed to `template_name`. +- **React `currentTarget` null in deferred state updaters**: `onChange={(e) => setX((prev) => ({ ...prev, q: e.currentTarget.value }))}` blanked the page on the first user input because `currentTarget` is cleared after the listener bubble ends, before React invokes the updater. Switched all M5 handlers to `e.target.value`, which persists on the synthetic event. + +### Fixed (post-M5 — scenario reorder 500 + cross-worker lock correctness) +- **`PUT /scenario-templates/{id}/tests` returned 500** (`backend/app/services/scenario_templates.py:218`): the two-argument form `pg_advisory_xact_lock(:n, :m)` failed with `function pg_advisory_xact_lock(smallint, bigint) does not exist`. Postgres only provides `(int4, int4)` and `(bigint)` overloads — psycopg promoted `m = hash(uuid) & 0xFFFFFFFF` (up to 2^32-1) to bigint and there's no matching overload. Switched to the single-argument bigint form with `CAST(:key AS bigint)`. +- **Cross-worker lock was a no-op** (same site): Python's built-in `hash()` is randomised per process via `PYTHONHASHSEED`, so each gunicorn worker computed a different key for the same `scenario_id`, and concurrent reorders on different workers acquired independent locks — defeating the serialisation. Replaced with `blake2b(scenario_id.bytes, digest_size=8)` interpreted as a signed int64. Stable, deterministic, fits in `bigint`. + +### Fixed (post-M5 UI — modal layout for the test-template editor) +- **Modal box capped its width at `max-w-2xl` and had no vertical scroll** (`frontend/src/components/ui/Modal.tsx`): opening **+ New test** rendered the 15-column MITRE matrix inside a 672 px frame with no height cap, so the matrix spilled to the right and the form bottom dropped below the viewport — buttons unreachable, no scroll. Added a `size` prop (default `2xl` for back-compat), `max-h-[calc(100vh-2rem)]` + `flex flex-col` on the dialog, and an inner `min-w-0 flex-1 overflow-y-auto` body so the header stays pinned while the form scrolls inside the modal. +- **MITRE matrix overflow-x failed to scroll inside the modal body** (`frontend/src/components/MitreTagPicker.tsx`): `overflow-x-auto` sat directly on the grid element, but the grid's intrinsic min-width (`15 × minmax(7rem, …)` = 1680 px) prevented it from shrinking below its content, so the grid spilled outside its parent instead of scrolling. Wrapped the grid in a dedicated `overflow-x-auto rounded min-w-0 w-full` scroller and added `min-w-0` to the picker root so the constraint propagates from the modal body. The grid now scrolls horizontally inside the modal. +- **`grid gap-3` form layout in the test-template modal propagated `min-width: auto`** (`frontend/src/pages/AdminTestsPage.tsx`): each grid item refused to shrink below its widest child, so the picker dragged the form (and the body) past the modal width. Switched the form to `flex flex-col gap-3 min-w-0`, which breaks the propagation while preserving vertical spacing. +- **Test-template modal now uses `size="7xl"`** and the scenario-template modal `size="3xl"` to match their content density. + +### Fixed (post-M5 review pass — spec-reviewer + code-reviewer) +- **Filter combinator was OR, not AND** (`backend/app/services/test_templates.py:235`): `?tactic=TA0002&technique=T1059` returned templates matching *either* facet instead of *both*. Pre-fix also pooled all three UUIDs into a shared `IN` list across three columns, theoretically allowing a UUID collision to match across kinds. Refactored to one IN-subquery per facet, ANDed together via repeated `WHERE id IN (...)`. +- **Concurrent reorder race on `set_scenario_tests`** (`backend/app/services/scenario_templates.py:207`): two parallel reorders on the same scenario could deadlock on the `UNIQUE(scenario_id, position)` constraint under READ COMMITTED. Added a per-scenario `pg_advisory_xact_lock(0x5C3, hash(scenario_id))` mirroring the M4 `/mitre/sync` pattern; different scenarios don't contend. +- **N+1 on `_to_view` MITRE resolution** (`backend/app/services/test_templates.py:160`): rendering K templates with ~T tags each fired up to K×T `s.get(...)` calls. Added `_to_views_batch` that pre-builds `{uuid → MitreRow}` maps in 3 queries and feeds them to per-template view assembly; `list_test_templates` now issues 4 queries total regardless of list size. +- **Wire-level item length cap on `tags` / `expected_iocs`** (`backend/app/api/test_templates.py:18-21`): the DB columns are `ARRAY(String(64))` / `ARRAY(String(255))` but the API layer only capped the LIST length, not item strings — long inputs hit the driver with `StringDataRightTruncation`. Added `Annotated[str, StringConstraints(...)]` types so the API returns 400 with a clean validation error. +- **Front-end mutation cache hygiene** (`frontend/src/pages/AdminScenariosPage.tsx:148-156`): `updateMeta` and `setTests` mutations are run sequentially in `submit()`; on partial failure (metadata saved but reorder failed) the cache stayed stale. Both mutations now `onSettled: invalidate` so whatever step landed is reflected without manual refresh. +- **Backend vs front-end consistency on duplicate tests in a scenario** (`frontend/src/pages/AdminScenariosPage.tsx:227-231`): the backend allows the same `test_template` to appear multiple times (chained ops; the UNIQUE constraint is `(scenario_id, position)` not `(scenario_id, test_template_id)`), but the catalogue picker was filtering out already-picked items. Removed the filter — only soft-deleted tests are excluded now. +- **Test coverage closure** (`backend/tests/test_templates.py`): +4 pytest (tactic+technique AND-semantics, `extra="forbid"` rejection, empty `mitre_tags` explicit clear, 65-char tag length cap → 400). Total backend now 23 M5 tests + 39 elsewhere = 81 pass. + ### Added — M4 (MITRE ATT&CK Enterprise) - **STIX 2.1 parser + upsert** (`app/services/mitre_seed.py`): stdlib-only (`urllib.request` + `hashlib`), pinned to Enterprise v19.0 (`enterprise-attack-19.0.json`, sha256 `df520ea0…`). Parses 25k+ STIX objects → 15 tactics, 222 techniques, 475 sub-techniques in ~1.1 s. Skips revoked + deprecated, resolves sub-technique parents via `relationship[subtechnique-of]` with a `T1003.001 → T1003` dotted-id fallback, copies kill-chain phases into the `mitre_technique_tactics` M2M. - **CLI**: `flask metamorph seed-mitre [--source ] [--checksum-sha256 ] [--skip-checksum]` (`app/cli.py`). `make seed-mitre` wraps it. diff --git a/README.md b/README.md index da324e7..59df5d2 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ Collaborative purple-team platform. Red team logs the tests they execute (procedure, command, timestamp); blue team annotates each test with detection evidence (alerts, logs, files). At the end of an engagement, Metamorph generates a standalone reveal.js slide deck classified by MITRE ATT&CK tactic. -> **Status**: M0–M4 delivered (bootstrap → DB schema → auth → RBAC → MITRE ATT&CK reference). See `tasks/spec.md` for the full specification and `tasks/todo.md` for the milestone-by-milestone plan. +> **Status**: M0–M5 delivered (bootstrap → DB schema → auth → RBAC → MITRE ATT&CK reference → test & scenario templates). See `tasks/spec.md` for the full specification and `tasks/todo.md` for the milestone-by-milestone plan. ## Stack @@ -11,6 +11,7 @@ Collaborative purple-team platform. Red team logs the tests they execute (proced - **Auth (M2+)**: JWT access (1h) + refresh (30d), Argon2id, invite-link enrollment. - **RBAC (M3+)**: atomic permissions (31 codes) bundled into custom groups; 3 system groups seeded (`admin` / `redteam` / `blueteam`). - **MITRE ATT&CK (M4+)**: Enterprise reference catalogue pinned to v19.0, seedable via `make seed-mitre`. +- **Template catalogue (M5+)**: reusable `test_templates` (markdown procedure, OPSEC level, free tags, expected IOCs, MITRE tags) + ordered `scenario_templates` with drag-and-drop reordering. Admin pages at `/admin/tests` and `/admin/scenarios`. - **Delivery**: docker-compose. TLS termination is expected to be handled by an external reverse proxy in production. ## Quickstart diff --git a/backend/app/api/diag.py b/backend/app/api/diag.py index 4fa3880..8dd0e20 100644 --- a/backend/app/api/diag.py +++ b/backend/app/api/diag.py @@ -73,6 +73,17 @@ def reset_test_state(): "user_groups, settings, groups RESTART IDENTITY CASCADE" ) ) + # Template catalogue reset (M5). The MITRE truncate below cascades to + # the polymorphic tag join, but the template rows themselves must be + # wiped first because `scenario_template_tests.test_template_id` is + # ON DELETE RESTRICT. + conn.execute( + text( + "TRUNCATE scenario_template_tests, scenario_templates, " + "test_template_mitre_tags, test_templates " + "RESTART IDENTITY CASCADE" + ) + ) # MITRE reference reset — kept in sync with `settings` so a freshly # reset stack has `GET /mitre/status` and `GET /mitre/tactics` agree # ("no data, no last_sync"). The e2e suite re-syncs via /mitre/sync diff --git a/backend/app/api/scenario_templates.py b/backend/app/api/scenario_templates.py new file mode 100644 index 0000000..4481f5d --- /dev/null +++ b/backend/app/api/scenario_templates.py @@ -0,0 +1,208 @@ +"""Scenario-template CRUD + reorder endpoints. + +`PUT //tests` is the reorder/replace endpoint — it takes the full ordered +list and rewrites the join rows. There's no partial mutation API for the test +list: the wire contract is simpler and the client (drag-and-drop) already +holds the full ordering. +""" + +from __future__ import annotations + +import logging +import uuid +from typing import Any + +from flask import Blueprint, jsonify, request +from pydantic import BaseModel, Field, ValidationError + +from app.core.auth_decorators import require_auth, require_perm +from app.services import scenario_templates as svc + +bp = Blueprint("scenario_templates", __name__, url_prefix="/scenario-templates") +log = logging.getLogger("metamorph.api.scenario_templates") + + +class CreateScenarioPayload(BaseModel): + name: str = Field(min_length=1, max_length=255) + description: str | None = Field(default=None, max_length=4000) + test_template_ids: list[uuid.UUID] = Field(default_factory=list, max_length=512) + + model_config = {"extra": "forbid"} + + +class UpdateScenarioPayload(BaseModel): + name: str | None = Field(default=None, min_length=1, max_length=255) + description: str | None = Field(default=None, max_length=4000) + + model_config = {"extra": "forbid"} + + +class SetTestsPayload(BaseModel): + test_template_ids: list[uuid.UUID] = Field(default_factory=list, max_length=512) + + model_config = {"extra": "forbid"} + + +def _serialize(sc: svc.ScenarioTemplateView) -> dict[str, Any]: + return { + "id": str(sc.id), + "name": sc.name, + "description": sc.description, + "tests": [ + { + "position": t.position, + "test_template_id": str(t.test_template_id), + "test_template_name": t.test_template_name, + "test_template_deleted": t.test_template_deleted, + } + for t in sc.tests + ], + "tests_count": sc.tests_count, + "deleted_at": sc.deleted_at.isoformat() if sc.deleted_at else None, + "created_at": sc.created_at.isoformat(), + "updated_at": sc.updated_at.isoformat(), + } + + +def _parse_uuid_or_400(raw: str): + try: + return uuid.UUID(raw) + except ValueError: + return None + + +def _pagination_args() -> tuple[int, int] | tuple[None, tuple[int, str]]: + try: + limit = int(request.args.get("limit", "100")) + offset = int(request.args.get("offset", "0")) + except ValueError: + return None, (400, "invalid_pagination") + return max(1, min(limit, 500)), max(0, offset) + + +@bp.get("") +@require_auth +@require_perm("scenario_template.read") +def list_scenario_templates(): + paging = _pagination_args() + if paging[0] is None: + return jsonify({"error": paging[1][1]}), paging[1][0] + limit, offset = paging + q = request.args.get("q") or None + include_deleted = request.args.get("include_deleted", "false").lower() == "true" + items, total = svc.list_scenario_templates( + q=q, include_deleted=include_deleted, limit=limit, offset=offset + ) + return jsonify( + { + "items": [_serialize(it) for it in items], + "total": total, + "limit": limit, + "offset": offset, + } + ) + + +@bp.get("/") +@require_auth +@require_perm("scenario_template.read") +def get_scenario_template(scenario_id: str): + sid = _parse_uuid_or_400(scenario_id) + if sid is None: + return jsonify({"error": "invalid_id"}), 400 + include_deleted = request.args.get("include_deleted", "false").lower() == "true" + try: + view = svc.get_scenario_template(sid, include_deleted=include_deleted) + except svc.ScenarioTemplateNotFound: + return jsonify({"error": "not_found"}), 404 + return jsonify(_serialize(view)) + + +@bp.post("") +@require_auth +@require_perm("scenario_template.create") +def create_scenario_template(): + try: + payload = CreateScenarioPayload.model_validate(request.get_json(silent=True) or {}) + except ValidationError as e: + return jsonify({"error": "invalid_request", "details": e.errors()}), 400 + try: + view = svc.create_scenario_template( + name=payload.name, + description=payload.description, + test_template_ids=payload.test_template_ids, + ) + except svc.UnknownTestTemplate as e: + return jsonify({"error": "unknown_test_template", "message": str(e)}), 400 + except ValueError as e: + return jsonify({"error": "invalid_request", "message": str(e)}), 400 + log.info( + "metamorph.scenario_template.created", + extra={"id": str(view.id), "tests": len(view.tests)}, + ) + return jsonify(_serialize(view)), 201 + + +@bp.patch("/") +@require_auth +@require_perm("scenario_template.update") +def update_scenario_template(scenario_id: str): + sid = _parse_uuid_or_400(scenario_id) + if sid is None: + return jsonify({"error": "invalid_id"}), 400 + raw = request.get_json(silent=True) or {} + try: + payload = UpdateScenarioPayload.model_validate(raw) + except ValidationError as e: + return jsonify({"error": "invalid_request", "details": e.errors()}), 400 + kwargs: dict[str, Any] = {} + if "name" in raw: + kwargs["name"] = payload.name + if "description" in raw: + kwargs["description"] = payload.description + try: + view = svc.update_scenario_template(sid, **kwargs) + except svc.ScenarioTemplateNotFound: + return jsonify({"error": "not_found"}), 404 + except ValueError as e: + return jsonify({"error": "invalid_request", "message": str(e)}), 400 + return jsonify(_serialize(view)) + + +@bp.put("//tests") +@require_auth +@require_perm("scenario_template.update") +def set_scenario_tests(scenario_id: str): + sid = _parse_uuid_or_400(scenario_id) + if sid is None: + return jsonify({"error": "invalid_id"}), 400 + try: + payload = SetTestsPayload.model_validate(request.get_json(silent=True) or {}) + except ValidationError as e: + return jsonify({"error": "invalid_request", "details": e.errors()}), 400 + try: + view = svc.set_scenario_tests(sid, payload.test_template_ids) + except svc.ScenarioTemplateNotFound: + return jsonify({"error": "not_found"}), 404 + except svc.UnknownTestTemplate as e: + return jsonify({"error": "unknown_test_template", "message": str(e)}), 400 + log.info( + "metamorph.scenario_template.tests_set", + extra={"id": str(sid), "tests": len(view.tests)}, + ) + return jsonify(_serialize(view)) + + +@bp.delete("/") +@require_auth +@require_perm("scenario_template.delete") +def soft_delete_scenario_template(scenario_id: str): + sid = _parse_uuid_or_400(scenario_id) + if sid is None: + return jsonify({"error": "invalid_id"}), 400 + try: + svc.soft_delete_scenario_template(sid) + except svc.ScenarioTemplateNotFound: + return jsonify({"error": "not_found"}), 404 + log.info("metamorph.scenario_template.soft_deleted", extra={"id": str(sid)}) + return jsonify({"ok": True}) diff --git a/backend/app/api/test_templates.py b/backend/app/api/test_templates.py new file mode 100644 index 0000000..8f18b55 --- /dev/null +++ b/backend/app/api/test_templates.py @@ -0,0 +1,257 @@ +"""Test-template CRUD endpoints. + +Reads gated by `test_template.read`. Writes gated by `test_template.{create, +update,delete}`. Service layer handles all DB work; this module only validates +the wire payload and shapes the JSON response. +""" + +from __future__ import annotations + +import logging +import uuid +from typing import Any + +from flask import Blueprint, jsonify, request +from pydantic import BaseModel, Field, StringConstraints, ValidationError +from typing import Annotated + +from app.core.auth_decorators import require_auth, require_perm +from app.services import test_templates as svc + +# Tag and IOC entries are stored as PG ARRAY(String(N)). Cap items at the wire +# layer so over-sized inputs return 400 with a useful message rather than the +# bare StringDataRightTruncation from the driver. +TagStr = Annotated[str, StringConstraints(min_length=1, max_length=64)] +IocStr = Annotated[str, StringConstraints(min_length=1, max_length=255)] + +bp = Blueprint("test_templates", __name__, url_prefix="/test-templates") +log = logging.getLogger("metamorph.api.test_templates") + + +# === Payload schemas ========================================================== + + +class MitreTagIn(BaseModel): + kind: str = Field(min_length=1) + external_id: str = Field(min_length=1, max_length=16) + + model_config = {"extra": "forbid"} + + +class CreateTestTemplatePayload(BaseModel): + name: str = Field(min_length=1, max_length=255) + description: str | None = Field(default=None, max_length=4000) + objective: str | None = Field(default=None, max_length=4000) + procedure_md: str | None = Field(default=None, max_length=32_000) + prerequisites_md: str | None = Field(default=None, max_length=32_000) + expected_result_red_md: str | None = Field(default=None, max_length=32_000) + expected_detection_blue_md: str | None = Field(default=None, max_length=32_000) + opsec_level: str = Field(default="medium") + tags: list[TagStr] = Field(default_factory=list, max_length=64) + expected_iocs: list[IocStr] = Field(default_factory=list, max_length=128) + mitre_tags: list[MitreTagIn] = Field(default_factory=list, max_length=64) + + model_config = {"extra": "forbid"} + + +class UpdateTestTemplatePayload(BaseModel): + name: str | None = Field(default=None, min_length=1, max_length=255) + description: str | None = Field(default=None, max_length=4000) + objective: str | None = Field(default=None, max_length=4000) + procedure_md: str | None = Field(default=None, max_length=32_000) + prerequisites_md: str | None = Field(default=None, max_length=32_000) + expected_result_red_md: str | None = Field(default=None, max_length=32_000) + expected_detection_blue_md: str | None = Field(default=None, max_length=32_000) + opsec_level: str | None = None + tags: list[TagStr] | None = Field(default=None, max_length=64) + expected_iocs: list[IocStr] | None = Field(default=None, max_length=128) + mitre_tags: list[MitreTagIn] | None = Field(default=None, max_length=64) + + model_config = {"extra": "forbid"} + + +# === Serializers ============================================================== + + +def _serialize(t: svc.TestTemplateView) -> dict[str, Any]: + return { + "id": str(t.id), + "name": t.name, + "description": t.description, + "objective": t.objective, + "procedure_md": t.procedure_md, + "prerequisites_md": t.prerequisites_md, + "expected_result_red_md": t.expected_result_red_md, + "expected_detection_blue_md": t.expected_detection_blue_md, + "opsec_level": t.opsec_level, + "tags": list(t.tags), + "expected_iocs": list(t.expected_iocs), + "mitre_tags": [ + {"kind": tag.kind, "external_id": tag.external_id, "name": tag.name, "url": tag.url} + for tag in t.mitre_tags + ], + "deleted_at": t.deleted_at.isoformat() if t.deleted_at else None, + "created_at": t.created_at.isoformat(), + "updated_at": t.updated_at.isoformat(), + } + + +def _parse_uuid_or_400(raw: str): + try: + return uuid.UUID(raw) + except ValueError: + return None + + +def _pagination_args() -> tuple[int, int] | tuple[None, tuple[int, str]]: + try: + limit = int(request.args.get("limit", "100")) + offset = int(request.args.get("offset", "0")) + except ValueError: + return None, (400, "invalid_pagination") + return max(1, min(limit, 500)), max(0, offset) + + +# === Endpoints ================================================================ + + +@bp.get("") +@require_auth +@require_perm("test_template.read") +def list_test_templates(): + paging = _pagination_args() + if paging[0] is None: + return jsonify({"error": paging[1][1]}), paging[1][0] + limit, offset = paging + q = request.args.get("q") or None + tactic = request.args.get("tactic") or None + technique = request.args.get("technique") or None + subtechnique = request.args.get("subtechnique") or None + opsec_level = request.args.get("opsec") or None + tag = request.args.get("tag") or None + include_deleted = request.args.get("include_deleted", "false").lower() == "true" + try: + items, total = svc.list_test_templates( + q=q, + tactic=tactic, + technique=technique, + subtechnique=subtechnique, + opsec_level=opsec_level, + tag=tag, + include_deleted=include_deleted, + limit=limit, + offset=offset, + ) + except ValueError as e: + return jsonify({"error": "invalid_request", "message": str(e)}), 400 + return jsonify( + { + "items": [_serialize(it) for it in items], + "total": total, + "limit": limit, + "offset": offset, + } + ) + + +@bp.get("/") +@require_auth +@require_perm("test_template.read") +def get_test_template(template_id: str): + tid = _parse_uuid_or_400(template_id) + if tid is None: + return jsonify({"error": "invalid_id"}), 400 + include_deleted = request.args.get("include_deleted", "false").lower() == "true" + try: + view = svc.get_test_template(tid, include_deleted=include_deleted) + except svc.TestTemplateNotFound: + return jsonify({"error": "not_found"}), 404 + return jsonify(_serialize(view)) + + +@bp.post("") +@require_auth +@require_perm("test_template.create") +def create_test_template(): + try: + payload = CreateTestTemplatePayload.model_validate(request.get_json(silent=True) or {}) + except ValidationError as e: + return jsonify({"error": "invalid_request", "details": e.errors()}), 400 + try: + view = svc.create_test_template( + name=payload.name, + description=payload.description, + objective=payload.objective, + procedure_md=payload.procedure_md, + prerequisites_md=payload.prerequisites_md, + expected_result_red_md=payload.expected_result_red_md, + expected_detection_blue_md=payload.expected_detection_blue_md, + opsec_level=payload.opsec_level, + tags=payload.tags, + expected_iocs=payload.expected_iocs, + mitre_tags=[svc.MitreTagRef(kind=t.kind, external_id=t.external_id) for t in payload.mitre_tags], + ) + except svc.UnknownMitreTag as e: + return jsonify({"error": "unknown_mitre_tag", "message": str(e)}), 400 + except ValueError as e: + return jsonify({"error": "invalid_request", "message": str(e)}), 400 + log.info( + "metamorph.test_template.created", + extra={"id": str(view.id), "template_name": view.name}, + ) + return jsonify(_serialize(view)), 201 + + +@bp.put("/") +@require_auth +@require_perm("test_template.update") +def update_test_template(template_id: str): + tid = _parse_uuid_or_400(template_id) + if tid is None: + return jsonify({"error": "invalid_id"}), 400 + raw = request.get_json(silent=True) or {} + try: + payload = UpdateTestTemplatePayload.model_validate(raw) + except ValidationError as e: + return jsonify({"error": "invalid_request", "details": e.errors()}), 400 + + # Only forward keys actually present in the body — model_validate leaves + # missing fields as None and we can't distinguish "explicitly null" from + # "omitted". The set of keys in `raw` is the wire-level intent. + kwargs: dict[str, Any] = {} + for field_name in ( + "name", "description", "objective", "procedure_md", "prerequisites_md", + "expected_result_red_md", "expected_detection_blue_md", + "opsec_level", "tags", "expected_iocs", + ): + if field_name in raw: + kwargs[field_name] = getattr(payload, field_name) + if "mitre_tags" in raw: + kwargs["mitre_tags"] = ( + [svc.MitreTagRef(kind=t.kind, external_id=t.external_id) for t in (payload.mitre_tags or [])] + ) + try: + view = svc.update_test_template(tid, **kwargs) + except svc.TestTemplateNotFound: + return jsonify({"error": "not_found"}), 404 + except svc.UnknownMitreTag as e: + return jsonify({"error": "unknown_mitre_tag", "message": str(e)}), 400 + except ValueError as e: + return jsonify({"error": "invalid_request", "message": str(e)}), 400 + log.info("metamorph.test_template.updated", extra={"id": str(tid), "fields": sorted(kwargs.keys())}) + return jsonify(_serialize(view)) + + +@bp.delete("/") +@require_auth +@require_perm("test_template.delete") +def soft_delete_test_template(template_id: str): + tid = _parse_uuid_or_400(template_id) + if tid is None: + return jsonify({"error": "invalid_id"}), 400 + try: + svc.soft_delete_test_template(tid) + except svc.TestTemplateNotFound: + return jsonify({"error": "not_found"}), 404 + log.info("metamorph.test_template.soft_deleted", extra={"id": str(tid)}) + return jsonify({"ok": True}) diff --git a/backend/app/api/v1.py b/backend/app/api/v1.py index 3423e2d..6b5379b 100644 --- a/backend/app/api/v1.py +++ b/backend/app/api/v1.py @@ -11,7 +11,9 @@ from app.api.health import bp as health_bp from app.api.invitations import bp as invitations_bp from app.api.mitre import bp as mitre_bp from app.api.permissions import bp as permissions_bp +from app.api.scenario_templates import bp as scenario_templates_bp from app.api.setup import bp as setup_bp +from app.api.test_templates import bp as test_templates_bp from app.api.users import bp as users_bp bp = Blueprint("v1", __name__, url_prefix="/api/v1") @@ -24,3 +26,5 @@ bp.register_blueprint(users_bp) bp.register_blueprint(groups_bp) bp.register_blueprint(permissions_bp) bp.register_blueprint(mitre_bp) +bp.register_blueprint(test_templates_bp) +bp.register_blueprint(scenario_templates_bp) diff --git a/backend/app/services/scenario_templates.py b/backend/app/services/scenario_templates.py new file mode 100644 index 0000000..258a023 --- /dev/null +++ b/backend/app/services/scenario_templates.py @@ -0,0 +1,259 @@ +"""CRUD service for `scenario_templates` + their ordered test list. + +Re-ordering is implemented as **full delete + re-insert** of the +`scenario_template_tests` rows. The UNIQUE (scenario_template_id, position) +constraint makes any naive position-swap fail mid-transaction; wiping the set +then re-inserting at positions 0..N-1 keeps the operation atomic and obvious. + +The same test_template may legitimately appear multiple times in a scenario +(chained operations), so we key on `(scenario_id, position)`, not +`(scenario_id, test_template_id)`. +""" + +from __future__ import annotations + +import hashlib +import uuid +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Any + +from sqlalchemy import func, or_, select, text +from sqlalchemy.orm import Session, selectinload + +_UNSET: Any = object() + +from app.db.session import session_scope +from app.models.template import ( + ScenarioTemplate, + ScenarioTemplateTest, + TestTemplate, +) + + +class ScenarioTemplateNotFound(Exception): + pass + + +class UnknownTestTemplate(Exception): + """Raised when a scenario references a non-existent or soft-deleted test.""" + + +@dataclass(frozen=True) +class ScenarioTestView: + position: int + test_template_id: uuid.UUID + test_template_name: str + test_template_deleted: bool + + +@dataclass(frozen=True) +class ScenarioTemplateView: + id: uuid.UUID + name: str + description: str | None + tests: list[ScenarioTestView] + tests_count: int + deleted_at: datetime | None + created_at: datetime + updated_at: datetime + + +def _to_view(s: Session, sc: ScenarioTemplate) -> ScenarioTemplateView: + test_ids = [link.test_template_id for link in sc.tests] + name_by_id: dict[uuid.UUID, tuple[str, bool]] = {} + if test_ids: + rows = s.scalars(select(TestTemplate).where(TestTemplate.id.in_(test_ids))).all() + for row in rows: + name_by_id[row.id] = (row.name, row.deleted_at is not None) + tests = [ + ScenarioTestView( + position=link.position, + test_template_id=link.test_template_id, + test_template_name=name_by_id.get(link.test_template_id, ("", True))[0], + test_template_deleted=name_by_id.get(link.test_template_id, ("", True))[1], + ) + for link in sc.tests + ] + return ScenarioTemplateView( + id=sc.id, + name=sc.name, + description=sc.description, + tests=tests, + tests_count=len(tests), + deleted_at=sc.deleted_at, + created_at=sc.created_at, + updated_at=sc.updated_at, + ) + + +def _base_query(): + return select(ScenarioTemplate).options(selectinload(ScenarioTemplate.tests)) + + +def list_scenario_templates( + *, + q: str | None = None, + include_deleted: bool = False, + limit: int = 100, + offset: int = 0, +) -> tuple[list[ScenarioTemplateView], int]: + with session_scope() as s: + stmt = _base_query().order_by(ScenarioTemplate.name.asc()) + count_stmt = select(func.count()).select_from(ScenarioTemplate) + if not include_deleted: + stmt = stmt.where(ScenarioTemplate.deleted_at.is_(None)) + count_stmt = count_stmt.where(ScenarioTemplate.deleted_at.is_(None)) + if q: + like = f"%{q.lower()}%" + cond = or_( + func.lower(ScenarioTemplate.name).like(like), + func.lower(ScenarioTemplate.description).like(like), + ) + stmt = stmt.where(cond) + count_stmt = count_stmt.where(cond) + total = s.scalar(count_stmt) or 0 + rows = s.scalars(stmt.limit(max(1, min(limit, 500))).offset(max(0, offset))).all() + return [_to_view(s, sc) for sc in rows], int(total) + + +def get_scenario_template(scenario_id: uuid.UUID, *, include_deleted: bool = False) -> ScenarioTemplateView: + with session_scope() as s: + sc = s.get(ScenarioTemplate, scenario_id) + if sc is None: + raise ScenarioTemplateNotFound() + if sc.deleted_at is not None and not include_deleted: + raise ScenarioTemplateNotFound() + return _to_view(s, sc) + + +def _validate_test_ids(s: Session, ids: list[uuid.UUID]) -> None: + """Reject unknown or soft-deleted test_template ids before persisting.""" + if not ids: + return + found = s.execute( + select(TestTemplate.id, TestTemplate.deleted_at).where(TestTemplate.id.in_(ids)) + ).all() + known = {row.id for row in found} + deleted = {row.id for row in found if row.deleted_at is not None} + missing = set(ids) - known + if missing: + raise UnknownTestTemplate(f"unknown test_template ids: {sorted(str(m) for m in missing)}") + if deleted: + raise UnknownTestTemplate( + f"cannot reference soft-deleted test_template ids: {sorted(str(d) for d in deleted)}" + ) + + +def _opt_str(value: str | None) -> str | None: + if value is None: + return None + s = value.strip() + return s or None + + +def create_scenario_template( + *, + name: str, + description: str | None = None, + test_template_ids: list[uuid.UUID] | None = None, +) -> ScenarioTemplateView: + name_norm = (name or "").strip() + if not name_norm: + raise ValueError("name is required") + ids = list(test_template_ids or []) + with session_scope() as s: + _validate_test_ids(s, ids) + sc = ScenarioTemplate( + name=name_norm, + description=_opt_str(description), + ) + s.add(sc) + s.flush() + for position, tid in enumerate(ids): + s.add( + ScenarioTemplateTest( + scenario_template_id=sc.id, + test_template_id=tid, + position=position, + ) + ) + s.flush() + s.refresh(sc) + return _to_view(s, sc) + + +def update_scenario_template( + scenario_id: uuid.UUID, + *, + name: str | None = None, + description: Any = _UNSET, +) -> ScenarioTemplateView: + with session_scope() as s: + sc = s.get(ScenarioTemplate, scenario_id) + if sc is None or sc.deleted_at is not None: + raise ScenarioTemplateNotFound() + if name is not None: + n = name.strip() + if not n: + raise ValueError("name cannot be empty") + sc.name = n + if description is not _UNSET: + sc.description = _opt_str(description) + s.flush() + s.refresh(sc) + return _to_view(s, sc) + + +def set_scenario_tests( + scenario_id: uuid.UUID, + test_template_ids: list[uuid.UUID], +) -> ScenarioTemplateView: + """Replace the entire ordered test list. `position` becomes the index. + + Acquires a per-scenario advisory lock to serialise concurrent reorders. + Without it, two parallel `PUT /scenario-templates/{id}/tests` calls would + race on the wipe-then-insert sequence and deadlock on the UNIQUE(position) + constraint under READ COMMITTED. Mirrors the M4 pattern on /mitre/sync. + """ + with session_scope() as s: + # Lock keyed on the scenario UUID — different scenarios don't block + # each other. Single bigint form so we don't have to juggle int32 + # signed ranges. blake2b is used instead of Python's built-in hash() + # because the latter is randomised per-process (PYTHONHASHSEED), so + # two gunicorn workers would compute different keys for the same + # scenario and the lock wouldn't serialise across them. + digest = hashlib.blake2b(scenario_id.bytes, digest_size=8).digest() + lock_key = int.from_bytes(digest, "big", signed=True) + s.execute( + text("SELECT pg_advisory_xact_lock(CAST(:key AS bigint))"), + {"key": lock_key}, + ) + sc = s.get(ScenarioTemplate, scenario_id) + if sc is None or sc.deleted_at is not None: + raise ScenarioTemplateNotFound() + _validate_test_ids(s, test_template_ids) + # Wipe then re-insert. The UNIQUE(position) constraint forbids a + # naive UPDATE-swap; full-replace keeps the op atomic + readable. + for link in list(sc.tests): + s.delete(link) + s.flush() + for position, tid in enumerate(test_template_ids): + s.add( + ScenarioTemplateTest( + scenario_template_id=sc.id, + test_template_id=tid, + position=position, + ) + ) + s.flush() + s.refresh(sc) + return _to_view(s, sc) + + +def soft_delete_scenario_template(scenario_id: uuid.UUID) -> None: + with session_scope() as s: + sc = s.get(ScenarioTemplate, scenario_id) + if sc is None or sc.deleted_at is not None: + raise ScenarioTemplateNotFound() + sc.deleted_at = datetime.now(tz=timezone.utc) diff --git a/backend/app/services/test_templates.py b/backend/app/services/test_templates.py new file mode 100644 index 0000000..8702e78 --- /dev/null +++ b/backend/app/services/test_templates.py @@ -0,0 +1,495 @@ +"""CRUD service for `test_templates` + their MITRE tags. + +The MITRE tag set is **fully replaced** on every update — partial mutation of +the join rows would force the API client to track tag UUIDs they never created. +The polymorphic join (one of `tactic_id` / `technique_id` / `subtechnique_id` +populated) is owned here: callers pass `(kind, external_id)` tuples and we +resolve them to the matching MITRE row. +""" + +from __future__ import annotations + +import uuid +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Any, Iterable + +from sqlalchemy import func, or_, select +from sqlalchemy.orm import Session, selectinload + +_UNSET: Any = object() + +from app.db.session import session_scope +from app.db.types import MITRE_KINDS, OPSEC_LEVELS +from app.models.mitre import MitreSubtechnique, MitreTactic, MitreTechnique +from app.models.template import TestTemplate, TestTemplateMitreTag + + +class TestTemplateNotFound(Exception): + pass + + +class UnknownMitreTag(Exception): + """Raised when an (kind, external_id) tuple doesn't resolve to a known MITRE row.""" + + +@dataclass(frozen=True) +class MitreTagRef: + """Inbound MITRE tag reference. `external_id` is the ATT&CK identifier + (TA…/T…/T….…) — we resolve it server-side, the client never sees UUIDs. + """ + + kind: str # "tactic" | "technique" | "subtechnique" + external_id: str + + +@dataclass(frozen=True) +class MitreTagView: + kind: str + external_id: str + name: str + url: str | None + + +@dataclass(frozen=True) +class TestTemplateView: + id: uuid.UUID + name: str + description: str | None + objective: str | None + procedure_md: str | None + prerequisites_md: str | None + expected_result_red_md: str | None + expected_detection_blue_md: str | None + opsec_level: str + tags: list[str] + expected_iocs: list[str] + mitre_tags: list[MitreTagView] + deleted_at: datetime | None + created_at: datetime + updated_at: datetime + + +def _validate_opsec(value: str) -> str: + if value not in OPSEC_LEVELS: + raise ValueError(f"opsec_level must be one of {OPSEC_LEVELS}") + return value + + +def _normalize_string_list(values: Iterable[str] | None) -> list[str]: + if not values: + return [] + seen: set[str] = set() + out: list[str] = [] + for raw in values: + if not isinstance(raw, str): + raise ValueError("list items must be strings") + v = raw.strip() + if not v or v in seen: + continue + seen.add(v) + out.append(v) + return out + + +def _resolve_mitre_refs(s: Session, refs: list[MitreTagRef]) -> list[TestTemplateMitreTag]: + """Translate `(kind, external_id)` pairs into half-populated join rows. + + Validates that: + - `kind` is one of the supported values + - each external_id resolves to an existing MITRE row + - the combination is unique inside the payload (de-duped silently — same + tag twice is a no-op, not an error) + """ + if not refs: + return [] + # Dedupe input + deduped: dict[tuple[str, str], MitreTagRef] = {} + for ref in refs: + if ref.kind not in MITRE_KINDS: + raise ValueError(f"mitre tag kind must be one of {MITRE_KINDS}") + if not ref.external_id: + raise ValueError("mitre tag external_id is required") + deduped[(ref.kind, ref.external_id)] = ref + + tactic_ids = {r.external_id for r in deduped.values() if r.kind == "tactic"} + technique_ids = {r.external_id for r in deduped.values() if r.kind == "technique"} + subtechnique_ids = {r.external_id for r in deduped.values() if r.kind == "subtechnique"} + + tactic_map = { + t.external_id: t.id + for t in s.scalars(select(MitreTactic).where(MitreTactic.external_id.in_(tactic_ids))).all() + } + technique_map = { + t.external_id: t.id + for t in s.scalars(select(MitreTechnique).where(MitreTechnique.external_id.in_(technique_ids))).all() + } + subtechnique_map = { + sb.external_id: sb.id + for sb in s.scalars( + select(MitreSubtechnique).where(MitreSubtechnique.external_id.in_(subtechnique_ids)) + ).all() + } + + rows: list[TestTemplateMitreTag] = [] + missing: list[tuple[str, str]] = [] + for ref in deduped.values(): + if ref.kind == "tactic": + mid = tactic_map.get(ref.external_id) + if mid is None: + missing.append((ref.kind, ref.external_id)) + continue + rows.append(TestTemplateMitreTag(mitre_kind="tactic", tactic_id=mid)) + elif ref.kind == "technique": + mid = technique_map.get(ref.external_id) + if mid is None: + missing.append((ref.kind, ref.external_id)) + continue + rows.append(TestTemplateMitreTag(mitre_kind="technique", technique_id=mid)) + else: + mid = subtechnique_map.get(ref.external_id) + if mid is None: + missing.append((ref.kind, ref.external_id)) + continue + rows.append(TestTemplateMitreTag(mitre_kind="subtechnique", subtechnique_id=mid)) + if missing: + raise UnknownMitreTag(f"unknown MITRE tags: {sorted(missing)}") + return rows + + +def _resolve_mitre_views(s: Session, tags: list[TestTemplateMitreTag]) -> list[MitreTagView]: + """Batch-resolve polymorphic MITRE FKs into MitreTagViews in 3 queries + total — one per kind — regardless of how many tags or templates the + caller is rendering. + """ + tactic_ids = {t.tactic_id for t in tags if t.mitre_kind == "tactic" and t.tactic_id is not None} + technique_ids = {t.technique_id for t in tags if t.mitre_kind == "technique" and t.technique_id is not None} + sub_ids = {t.subtechnique_id for t in tags if t.mitre_kind == "subtechnique" and t.subtechnique_id is not None} + + tactic_map: dict[uuid.UUID, MitreTactic] = {} + technique_map: dict[uuid.UUID, MitreTechnique] = {} + sub_map: dict[uuid.UUID, MitreSubtechnique] = {} + if tactic_ids: + tactic_map = {row.id: row for row in s.scalars(select(MitreTactic).where(MitreTactic.id.in_(tactic_ids))).all()} + if technique_ids: + technique_map = { + row.id: row + for row in s.scalars(select(MitreTechnique).where(MitreTechnique.id.in_(technique_ids))).all() + } + if sub_ids: + sub_map = { + row.id: row + for row in s.scalars(select(MitreSubtechnique).where(MitreSubtechnique.id.in_(sub_ids))).all() + } + + views: list[MitreTagView] = [] + for tag in tags: + if tag.mitre_kind == "tactic" and tag.tactic_id in tactic_map: + row_t = tactic_map[tag.tactic_id] + views.append(MitreTagView(kind="tactic", external_id=row_t.external_id, name=row_t.name, url=row_t.url)) + elif tag.mitre_kind == "technique" and tag.technique_id in technique_map: + row_te = technique_map[tag.technique_id] + views.append(MitreTagView(kind="technique", external_id=row_te.external_id, name=row_te.name, url=row_te.url)) + elif tag.mitre_kind == "subtechnique" and tag.subtechnique_id in sub_map: + row_sb = sub_map[tag.subtechnique_id] + views.append(MitreTagView(kind="subtechnique", external_id=row_sb.external_id, name=row_sb.name, url=row_sb.url)) + views.sort(key=lambda v: (v.kind, v.external_id)) + return views + + +def _to_views_batch(s: Session, templates: list[TestTemplate]) -> list[TestTemplateView]: + """List-level batcher: one bulk MITRE resolve for all templates' tags. + + For a list of K templates with ~T tags each, this issues 3 queries total + (one per MITRE kind) instead of 3K. We build (kind, uuid) → row maps + once, then assemble each template's view in memory. + """ + tactic_ids: set[uuid.UUID] = set() + technique_ids: set[uuid.UUID] = set() + sub_ids: set[uuid.UUID] = set() + for t in templates: + for tag in t.mitre_tags: + if tag.mitre_kind == "tactic" and tag.tactic_id is not None: + tactic_ids.add(tag.tactic_id) + elif tag.mitre_kind == "technique" and tag.technique_id is not None: + technique_ids.add(tag.technique_id) + elif tag.mitre_kind == "subtechnique" and tag.subtechnique_id is not None: + sub_ids.add(tag.subtechnique_id) + + tactic_map: dict[uuid.UUID, MitreTactic] = ( + {row.id: row for row in s.scalars(select(MitreTactic).where(MitreTactic.id.in_(tactic_ids))).all()} + if tactic_ids + else {} + ) + technique_map: dict[uuid.UUID, MitreTechnique] = ( + {row.id: row for row in s.scalars(select(MitreTechnique).where(MitreTechnique.id.in_(technique_ids))).all()} + if technique_ids + else {} + ) + sub_map: dict[uuid.UUID, MitreSubtechnique] = ( + {row.id: row for row in s.scalars(select(MitreSubtechnique).where(MitreSubtechnique.id.in_(sub_ids))).all()} + if sub_ids + else {} + ) + + def _views_for(tags: list[TestTemplateMitreTag]) -> list[MitreTagView]: + out: list[MitreTagView] = [] + for tag in tags: + if tag.mitre_kind == "tactic" and tag.tactic_id in tactic_map: + row_t = tactic_map[tag.tactic_id] + out.append(MitreTagView(kind="tactic", external_id=row_t.external_id, name=row_t.name, url=row_t.url)) + elif tag.mitre_kind == "technique" and tag.technique_id in technique_map: + row_te = technique_map[tag.technique_id] + out.append(MitreTagView(kind="technique", external_id=row_te.external_id, name=row_te.name, url=row_te.url)) + elif tag.mitre_kind == "subtechnique" and tag.subtechnique_id in sub_map: + row_sb = sub_map[tag.subtechnique_id] + out.append(MitreTagView(kind="subtechnique", external_id=row_sb.external_id, name=row_sb.name, url=row_sb.url)) + out.sort(key=lambda v: (v.kind, v.external_id)) + return out + + views: list[TestTemplateView] = [] + for t in templates: + views.append( + TestTemplateView( + id=t.id, + name=t.name, + description=t.description, + objective=t.objective, + procedure_md=t.procedure_md, + prerequisites_md=t.prerequisites_md, + expected_result_red_md=t.expected_result_red_md, + expected_detection_blue_md=t.expected_detection_blue_md, + opsec_level=t.opsec_level, + tags=list(t.tags or []), + expected_iocs=list(t.expected_iocs or []), + mitre_tags=_views_for(list(t.mitre_tags)), + deleted_at=t.deleted_at, + created_at=t.created_at, + updated_at=t.updated_at, + ) + ) + return views + + +def _to_view(s: Session, t: TestTemplate) -> TestTemplateView: + tag_views = _resolve_mitre_views(s, list(t.mitre_tags)) + return TestTemplateView( + id=t.id, + name=t.name, + description=t.description, + objective=t.objective, + procedure_md=t.procedure_md, + prerequisites_md=t.prerequisites_md, + expected_result_red_md=t.expected_result_red_md, + expected_detection_blue_md=t.expected_detection_blue_md, + opsec_level=t.opsec_level, + tags=list(t.tags or []), + expected_iocs=list(t.expected_iocs or []), + mitre_tags=tag_views, + deleted_at=t.deleted_at, + created_at=t.created_at, + updated_at=t.updated_at, + ) + + +def _base_query(): + return select(TestTemplate).options(selectinload(TestTemplate.mitre_tags)) + + +def list_test_templates( + *, + q: str | None = None, + tactic: str | None = None, # external_id like "TA0006" + technique: str | None = None, + subtechnique: str | None = None, + opsec_level: str | None = None, + tag: str | None = None, + include_deleted: bool = False, + limit: int = 100, + offset: int = 0, +) -> tuple[list[TestTemplateView], int]: + with session_scope() as s: + stmt = _base_query().order_by(TestTemplate.name.asc()) + count_stmt = select(func.count()).select_from(TestTemplate) + if not include_deleted: + stmt = stmt.where(TestTemplate.deleted_at.is_(None)) + count_stmt = count_stmt.where(TestTemplate.deleted_at.is_(None)) + if q: + like = f"%{q.lower()}%" + cond = or_( + func.lower(TestTemplate.name).like(like), + func.lower(TestTemplate.description).like(like), + ) + stmt = stmt.where(cond) + count_stmt = count_stmt.where(cond) + if opsec_level: + _validate_opsec(opsec_level) + stmt = stmt.where(TestTemplate.opsec_level == opsec_level) + count_stmt = count_stmt.where(TestTemplate.opsec_level == opsec_level) + if tag: + stmt = stmt.where(TestTemplate.tags.any(tag)) + count_stmt = count_stmt.where(TestTemplate.tags.any(tag)) + + # MITRE facets: each provided facet (tactic, technique, subtechnique) is + # AND-combined — a template tagged BOTH `TA0006` AND `T1003` matches a + # query with `?tactic=TA0006&technique=T1003`, but a template tagged + # only `TA0006` does NOT. Each facet matches strictly its own column + # (no cross-column UUID collision risk). + def _facet_subquery(column, mitre_id: uuid.UUID): + return ( + select(TestTemplateMitreTag.test_template_id) + .where(column == mitre_id) + .distinct() + ) + + if tactic: + tac = s.scalar(select(MitreTactic).where(MitreTactic.external_id == tactic)) + if tac is None: + return [], 0 + sub_q = _facet_subquery(TestTemplateMitreTag.tactic_id, tac.id) + stmt = stmt.where(TestTemplate.id.in_(sub_q)) + count_stmt = count_stmt.where(TestTemplate.id.in_(sub_q)) + if technique: + tech = s.scalar(select(MitreTechnique).where(MitreTechnique.external_id == technique)) + if tech is None: + return [], 0 + sub_q = _facet_subquery(TestTemplateMitreTag.technique_id, tech.id) + stmt = stmt.where(TestTemplate.id.in_(sub_q)) + count_stmt = count_stmt.where(TestTemplate.id.in_(sub_q)) + if subtechnique: + sub = s.scalar(select(MitreSubtechnique).where(MitreSubtechnique.external_id == subtechnique)) + if sub is None: + return [], 0 + sub_q = _facet_subquery(TestTemplateMitreTag.subtechnique_id, sub.id) + stmt = stmt.where(TestTemplate.id.in_(sub_q)) + count_stmt = count_stmt.where(TestTemplate.id.in_(sub_q)) + + total = s.scalar(count_stmt) or 0 + rows = s.scalars(stmt.limit(max(1, min(limit, 500))).offset(max(0, offset))).all() + return _to_views_batch(s, list(rows)), int(total) + + +def get_test_template(template_id: uuid.UUID, *, include_deleted: bool = False) -> TestTemplateView: + with session_scope() as s: + t = s.get(TestTemplate, template_id) + if t is None: + raise TestTemplateNotFound() + if t.deleted_at is not None and not include_deleted: + raise TestTemplateNotFound() + return _to_view(s, t) + + +def create_test_template( + *, + name: str, + description: str | None = None, + objective: str | None = None, + procedure_md: str | None = None, + prerequisites_md: str | None = None, + expected_result_red_md: str | None = None, + expected_detection_blue_md: str | None = None, + opsec_level: str = "medium", + tags: list[str] | None = None, + expected_iocs: list[str] | None = None, + mitre_tags: list[MitreTagRef] | None = None, +) -> TestTemplateView: + name_norm = (name or "").strip() + if not name_norm: + raise ValueError("name is required") + _validate_opsec(opsec_level) + norm_tags = _normalize_string_list(tags) + norm_iocs = _normalize_string_list(expected_iocs) + with session_scope() as s: + t = TestTemplate( + name=name_norm, + description=_opt_str(description), + objective=_opt_str(objective), + procedure_md=procedure_md or None, + prerequisites_md=prerequisites_md or None, + expected_result_red_md=expected_result_red_md or None, + expected_detection_blue_md=expected_detection_blue_md or None, + opsec_level=opsec_level, + tags=norm_tags, + expected_iocs=norm_iocs, + ) + s.add(t) + s.flush() + if mitre_tags: + rows = _resolve_mitre_refs(s, mitre_tags) + for row in rows: + row.test_template_id = t.id + s.add(row) + s.flush() + s.refresh(t) + return _to_view(s, t) + + +def _opt_str(value: str | None) -> str | None: + if value is None: + return None + s = value.strip() + return s or None + + +def update_test_template( + template_id: uuid.UUID, + *, + name: str | None = None, + description: Any = _UNSET, + objective: Any = _UNSET, + procedure_md: Any = _UNSET, + prerequisites_md: Any = _UNSET, + expected_result_red_md: Any = _UNSET, + expected_detection_blue_md: Any = _UNSET, + opsec_level: str | None = None, + tags: Any = _UNSET, + expected_iocs: Any = _UNSET, + mitre_tags: Any = _UNSET, +) -> TestTemplateView: + with session_scope() as s: + t = s.get(TestTemplate, template_id) + if t is None or t.deleted_at is not None: + raise TestTemplateNotFound() + if name is not None: + n = name.strip() + if not n: + raise ValueError("name cannot be empty") + t.name = n + if description is not _UNSET: + t.description = _opt_str(description) + if objective is not _UNSET: + t.objective = _opt_str(objective) + if procedure_md is not _UNSET: + t.procedure_md = procedure_md or None + if prerequisites_md is not _UNSET: + t.prerequisites_md = prerequisites_md or None + if expected_result_red_md is not _UNSET: + t.expected_result_red_md = expected_result_red_md or None + if expected_detection_blue_md is not _UNSET: + t.expected_detection_blue_md = expected_detection_blue_md or None + if opsec_level is not None: + _validate_opsec(opsec_level) + t.opsec_level = opsec_level + if tags is not _UNSET: + t.tags = _normalize_string_list(tags) + if expected_iocs is not _UNSET: + t.expected_iocs = _normalize_string_list(expected_iocs) + if mitre_tags is not _UNSET: + for row in list(t.mitre_tags): + s.delete(row) + s.flush() + rows = _resolve_mitre_refs(s, list(mitre_tags or [])) + for row in rows: + row.test_template_id = t.id + s.add(row) + s.flush() + s.refresh(t) + return _to_view(s, t) + + +def soft_delete_test_template(template_id: uuid.UUID) -> None: + with session_scope() as s: + t = s.get(TestTemplate, template_id) + if t is None or t.deleted_at is not None: + raise TestTemplateNotFound() + t.deleted_at = datetime.now(tz=timezone.utc) diff --git a/backend/tests/test_templates.py b/backend/tests/test_templates.py new file mode 100644 index 0000000..dc623f0 --- /dev/null +++ b/backend/tests/test_templates.py @@ -0,0 +1,567 @@ +"""M5 — Template catalogue integration tests. + +Covers `test_template` and `scenario_template` CRUD + ordering + perm gating. +Relies on a minimal MITRE seed (T1059 / TA0001 / T1059.001) so the polymorphic +tag join can be exercised end-to-end. +""" + +from __future__ import annotations + +import json +import secrets + +import pytest +from sqlalchemy import text + +from app.core.install_token import regenerate_install_token +from app.main import create_app +from app.services import mitre_seed as mitre_svc + + +def _truncate_all(engine): + with engine.begin() as conn: + conn.execute( + text( + "TRUNCATE users, refresh_tokens, invitations, invitation_groups, " + "user_groups, group_permissions, permissions, settings, groups, " + "scenario_template_tests, scenario_templates, " + "test_template_mitre_tags, test_templates, " + "mitre_subtechniques, mitre_technique_tactics, mitre_techniques, " + "mitre_tactics RESTART IDENTITY CASCADE" + ) + ) + + +# Same minimal bundle as in test_mitre.py — keeps tag resolution deterministic +# without re-pulling the full enterprise STIX bundle. +_MINIMAL_BUNDLE = { + "type": "bundle", + "id": "bundle--00000000-0000-0000-0000-000000000002", + "spec_version": "2.1", + "objects": [ + { + "type": "x-mitre-tactic", + "id": "x-mitre-tactic--ta0001", + "name": "Initial Access", + "x_mitre_shortname": "initial-access", + "external_references": [ + {"source_name": "mitre-attack", "external_id": "TA0001"} + ], + }, + { + "type": "x-mitre-tactic", + "id": "x-mitre-tactic--ta0002", + "name": "Execution", + "x_mitre_shortname": "execution", + "external_references": [ + {"source_name": "mitre-attack", "external_id": "TA0002"} + ], + }, + { + "type": "attack-pattern", + "id": "attack-pattern--t1059", + "name": "Command and Scripting Interpreter", + "kill_chain_phases": [ + {"kill_chain_name": "mitre-attack", "phase_name": "execution"} + ], + "external_references": [ + {"source_name": "mitre-attack", "external_id": "T1059"} + ], + }, + { + "type": "attack-pattern", + "id": "attack-pattern--t1059-001", + "name": "PowerShell", + "x_mitre_is_subtechnique": True, + "external_references": [ + {"source_name": "mitre-attack", "external_id": "T1059.001"} + ], + }, + { + "type": "relationship", + "id": "relationship--rel1", + "relationship_type": "subtechnique-of", + "source_ref": "attack-pattern--t1059-001", + "target_ref": "attack-pattern--t1059", + }, + ], +} + + +@pytest.fixture(scope="module") +def app(db_engine_or_skip, tmp_path_factory): + _truncate_all(db_engine_or_skip) + bundle_path = tmp_path_factory.mktemp("m5") / "stix.json" + bundle_path.write_text(json.dumps(_MINIMAL_BUNDLE)) + mitre_svc.seed_mitre(source=bundle_path, expected_sha256=None) + flask_app = create_app() + flask_app.config.update(TESTING=True) + return flask_app + + +@pytest.fixture() +def client(app): + return app.test_client() + + +def _unique_email(prefix: str) -> str: + return f"{prefix}-{secrets.token_hex(4)}@metamorph.local" + + +@pytest.fixture(scope="module") +def admin(app): + token = regenerate_install_token() + email = _unique_email("admin") + password = "AdminPass1234!" + with app.test_client() as c: + r = c.post( + "/api/v1/setup", + json={"install_token": token, "email": email, "password": password}, + ) + assert r.status_code == 201, r.get_data(as_text=True) + return {"email": email, "password": password} + + +def _login(client, email: str, password: str) -> str: + r = client.post("/api/v1/auth/login", json={"email": email, "password": password}) + assert r.status_code == 200, r.get_data(as_text=True) + return r.get_json()["access_token"] + + +def _bearer(token: str) -> dict[str, str]: + return {"Authorization": f"Bearer {token}"} + + +@pytest.fixture() +def admin_token(client, admin) -> str: + return _login(client, admin["email"], admin["password"]) + + +# === Reader fixture: an invited user with only `test_template.read` ========= + + +def _bootstrap_user_without_perms(client, admin_token: str, prefix: str) -> tuple[str, str]: + email = _unique_email(prefix) + inv = client.post( + "/api/v1/invitations", + headers=_bearer(admin_token), + json={"email_hint": email}, + ) + token = inv.get_json()["token"] + password = "ReaderPass1234!" + client.post( + f"/api/v1/invitations/accept/{token}", + json={"email": email, "password": password}, + ) + return email, _login(client, email, password) + + +# === test_template CRUD ===================================================== + + +def _make_test(client, admin_token: str, **overrides): + body = { + "name": overrides.pop("name", f"Test {secrets.token_hex(2)}"), + "description": overrides.pop("description", "auto"), + "objective": "do thing", + "procedure_md": "1. step", + "expected_result_red_md": "red expectation", + "expected_detection_blue_md": "blue expectation", + "opsec_level": overrides.pop("opsec_level", "medium"), + "tags": overrides.pop("tags", ["fast"]), + "expected_iocs": ["evil.exe"], + "mitre_tags": overrides.pop("mitre_tags", [{"kind": "technique", "external_id": "T1059"}]), + **overrides, + } + r = client.post("/api/v1/test-templates", headers=_bearer(admin_token), json=body) + assert r.status_code == 201, r.get_data(as_text=True) + return r.get_json() + + +def test_create_test_template_with_mitre_tags(client, admin_token): + body = _make_test( + client, + admin_token, + name="PowerShell exec", + mitre_tags=[ + {"kind": "tactic", "external_id": "TA0002"}, + {"kind": "technique", "external_id": "T1059"}, + {"kind": "subtechnique", "external_id": "T1059.001"}, + ], + ) + assert body["opsec_level"] == "medium" + kinds = sorted((t["kind"], t["external_id"]) for t in body["mitre_tags"]) + assert kinds == [ + ("subtechnique", "T1059.001"), + ("tactic", "TA0002"), + ("technique", "T1059"), + ] + + +def test_create_test_template_rejects_unknown_mitre(client, admin_token): + r = client.post( + "/api/v1/test-templates", + headers=_bearer(admin_token), + json={ + "name": "Bad", + "mitre_tags": [{"kind": "technique", "external_id": "T9999"}], + }, + ) + assert r.status_code == 400 + assert r.get_json()["error"] == "unknown_mitre_tag" + + +def test_create_test_template_rejects_bad_opsec(client, admin_token): + r = client.post( + "/api/v1/test-templates", + headers=_bearer(admin_token), + json={"name": "Bad", "opsec_level": "burner"}, + ) + assert r.status_code == 400 + + +def test_list_test_templates_filter_by_tactic(client, admin_token): + _make_test( + client, + admin_token, + name="filterable-1", + mitre_tags=[{"kind": "tactic", "external_id": "TA0002"}], + ) + r = client.get( + "/api/v1/test-templates?tactic=TA0002", + headers=_bearer(admin_token), + ) + assert r.status_code == 200 + body = r.get_json() + names = [it["name"] for it in body["items"]] + assert "filterable-1" in names + + +def test_list_test_templates_filter_by_opsec(client, admin_token): + _make_test(client, admin_token, name="high-opsec", opsec_level="high") + r = client.get( + "/api/v1/test-templates?opsec=high", + headers=_bearer(admin_token), + ) + assert r.status_code == 200 + names = [it["name"] for it in r.get_json()["items"]] + assert "high-opsec" in names + assert all(it["opsec_level"] == "high" for it in r.get_json()["items"]) + + +def test_list_test_templates_filter_by_tag(client, admin_token): + _make_test(client, admin_token, name="tagged-fast", tags=["fast", "phish"]) + r = client.get( + "/api/v1/test-templates?tag=phish", + headers=_bearer(admin_token), + ) + assert r.status_code == 200 + names = [it["name"] for it in r.get_json()["items"]] + assert "tagged-fast" in names + + +def test_list_test_templates_search_q(client, admin_token): + _make_test(client, admin_token, name="unique-token-azertyuiop") + r = client.get( + "/api/v1/test-templates?q=AZERTYUIOP", # case-insensitive + headers=_bearer(admin_token), + ) + assert r.status_code == 200 + names = [it["name"] for it in r.get_json()["items"]] + assert "unique-token-azertyuiop" in names + + +def test_update_test_template_replaces_mitre_tags(client, admin_token): + body = _make_test( + client, + admin_token, + name="to-update", + mitre_tags=[{"kind": "tactic", "external_id": "TA0001"}], + ) + r = client.put( + f"/api/v1/test-templates/{body['id']}", + headers=_bearer(admin_token), + json={"mitre_tags": [{"kind": "technique", "external_id": "T1059"}]}, + ) + assert r.status_code == 200, r.get_data(as_text=True) + updated = r.get_json() + kinds = [(t["kind"], t["external_id"]) for t in updated["mitre_tags"]] + assert kinds == [("technique", "T1059")] + + +def test_update_test_template_partial_keeps_unset_fields(client, admin_token): + body = _make_test( + client, + admin_token, + name="partial-update", + opsec_level="low", + tags=["a", "b"], + ) + r = client.put( + f"/api/v1/test-templates/{body['id']}", + headers=_bearer(admin_token), + json={"name": "renamed"}, + ) + assert r.status_code == 200 + updated = r.get_json() + assert updated["name"] == "renamed" + assert updated["opsec_level"] == "low" # untouched + assert set(updated["tags"]) == {"a", "b"} # untouched + + +def test_soft_delete_then_list_hides_by_default(client, admin_token): + body = _make_test(client, admin_token, name="to-be-deleted") + r = client.delete( + f"/api/v1/test-templates/{body['id']}", headers=_bearer(admin_token) + ) + assert r.status_code == 200 + r2 = client.get("/api/v1/test-templates", headers=_bearer(admin_token)) + names = [it["name"] for it in r2.get_json()["items"]] + assert "to-be-deleted" not in names + # And reappears with include_deleted=true + r3 = client.get( + "/api/v1/test-templates?include_deleted=true", + headers=_bearer(admin_token), + ) + names3 = [it["name"] for it in r3.get_json()["items"]] + assert "to-be-deleted" in names3 + + +def test_read_perm_required(client, admin_token): + """A user without `test_template.read` gets 403.""" + _, eve_token = _bootstrap_user_without_perms(client, admin_token, "eve-noperm") + r = client.get("/api/v1/test-templates", headers=_bearer(eve_token)) + assert r.status_code == 403 + + +def test_write_perm_required(client, admin_token): + """A user with only `test_template.read` cannot create. + + Bootstrap path: create a dedicated group via the admin API, bind only the + `test_template.read` perm, then invite a user pre-assigned to that group. + """ + # 1. Create the read-only group + bind the single perm. + grp = client.post( + "/api/v1/groups", + headers=_bearer(admin_token), + json={"name": f"tpl-reader-{secrets.token_hex(2)}"}, + ).get_json() + r_set = client.put( + f"/api/v1/groups/{grp['id']}/permissions", + headers=_bearer(admin_token), + json={"codes": ["test_template.read"]}, + ) + assert r_set.status_code == 200, r_set.get_data(as_text=True) + + # 2. Invite a user already attached to that group. + email = _unique_email("alice-readonly") + password = "ReaderPass1234!" + inv = client.post( + "/api/v1/invitations", + headers=_bearer(admin_token), + json={"email_hint": email, "group_ids": [grp["id"]]}, + ).get_json() + client.post( + f"/api/v1/invitations/accept/{inv['token']}", + json={"email": email, "password": password}, + ) + token = _login(client, email, password) + + r = client.get("/api/v1/test-templates", headers=_bearer(token)) + assert r.status_code == 200, r.get_data(as_text=True) + r2 = client.post( + "/api/v1/test-templates", headers=_bearer(token), json={"name": "X"} + ) + assert r2.status_code == 403 + + +# === scenario_template CRUD ================================================= + + +def test_create_scenario_with_ordered_tests(client, admin_token): + a = _make_test(client, admin_token, name="scn-a") + b = _make_test(client, admin_token, name="scn-b") + c = _make_test(client, admin_token, name="scn-c") + r = client.post( + "/api/v1/scenario-templates", + headers=_bearer(admin_token), + json={ + "name": "phishing-flow", + "description": "click → exec → persist", + "test_template_ids": [a["id"], b["id"], c["id"]], + }, + ) + assert r.status_code == 201, r.get_data(as_text=True) + body = r.get_json() + assert body["tests_count"] == 3 + assert [t["position"] for t in body["tests"]] == [0, 1, 2] + assert [t["test_template_name"] for t in body["tests"]] == ["scn-a", "scn-b", "scn-c"] + + +def test_reorder_scenario_tests(client, admin_token): + a = _make_test(client, admin_token, name="reord-a") + b = _make_test(client, admin_token, name="reord-b") + c = _make_test(client, admin_token, name="reord-c") + created = client.post( + "/api/v1/scenario-templates", + headers=_bearer(admin_token), + json={ + "name": "reorder-me", + "test_template_ids": [a["id"], b["id"], c["id"]], + }, + ).get_json() + # Reverse order. + r = client.put( + f"/api/v1/scenario-templates/{created['id']}/tests", + headers=_bearer(admin_token), + json={"test_template_ids": [c["id"], b["id"], a["id"]]}, + ) + assert r.status_code == 200 + after = r.get_json() + assert [t["test_template_name"] for t in after["tests"]] == ["reord-c", "reord-b", "reord-a"] + # Re-reading via GET yields the same order — confirms persistence. + fresh = client.get( + f"/api/v1/scenario-templates/{created['id']}", headers=_bearer(admin_token) + ).get_json() + assert [t["test_template_name"] for t in fresh["tests"]] == ["reord-c", "reord-b", "reord-a"] + + +def test_scenario_rejects_unknown_test_id(client, admin_token): + r = client.post( + "/api/v1/scenario-templates", + headers=_bearer(admin_token), + json={ + "name": "bad", + "test_template_ids": ["00000000-0000-0000-0000-000000000000"], + }, + ) + assert r.status_code == 400 + assert r.get_json()["error"] == "unknown_test_template" + + +def test_scenario_rejects_soft_deleted_test_on_create(client, admin_token): + a = _make_test(client, admin_token, name="will-be-deleted") + client.delete(f"/api/v1/test-templates/{a['id']}", headers=_bearer(admin_token)) + r = client.post( + "/api/v1/scenario-templates", + headers=_bearer(admin_token), + json={"name": "linked", "test_template_ids": [a["id"]]}, + ) + assert r.status_code == 400 + assert r.get_json()["error"] == "unknown_test_template" + + +def test_scenario_surfaces_soft_deleted_test_after_link(client, admin_token): + """Once linked, a test can be soft-deleted without breaking the scenario — + the join row stays and the API flags the test as deleted.""" + a = _make_test(client, admin_token, name="linked-then-deleted") + sc = client.post( + "/api/v1/scenario-templates", + headers=_bearer(admin_token), + json={"name": "survives", "test_template_ids": [a["id"]]}, + ).get_json() + client.delete(f"/api/v1/test-templates/{a['id']}", headers=_bearer(admin_token)) + fresh = client.get( + f"/api/v1/scenario-templates/{sc['id']}", headers=_bearer(admin_token) + ).get_json() + assert fresh["tests"][0]["test_template_deleted"] is True + + +def test_scenario_soft_delete(client, admin_token): + sc = client.post( + "/api/v1/scenario-templates", + headers=_bearer(admin_token), + json={"name": "doomed-scn"}, + ).get_json() + r = client.delete( + f"/api/v1/scenario-templates/{sc['id']}", headers=_bearer(admin_token) + ) + assert r.status_code == 200 + names = [ + it["name"] + for it in client.get( + "/api/v1/scenario-templates", headers=_bearer(admin_token) + ).get_json()["items"] + ] + assert "doomed-scn" not in names + + +def test_scenario_perm_required(client, admin_token): + _, eve_token = _bootstrap_user_without_perms(client, admin_token, "scn-eve") + r = client.get("/api/v1/scenario-templates", headers=_bearer(eve_token)) + assert r.status_code == 403 + + +# === Post-review fixes ====================================================== + + +def test_list_filter_combines_facets_with_and_semantics(client, admin_token): + """A template tagged only `TA0002` is NOT in `?tactic=TA0002&technique=T1059`. + + Pre-fix the OR-combined query would return it. AND-combined semantics + (one IN subquery per facet) restrict the set to templates matching ALL + requested facets. + """ + a = _make_test( + client, + admin_token, + name="and-tactic-only", + mitre_tags=[{"kind": "tactic", "external_id": "TA0002"}], + ) + b = _make_test( + client, + admin_token, + name="and-both-tags", + mitre_tags=[ + {"kind": "tactic", "external_id": "TA0002"}, + {"kind": "technique", "external_id": "T1059"}, + ], + ) + r = client.get( + "/api/v1/test-templates?tactic=TA0002&technique=T1059", + headers=_bearer(admin_token), + ) + assert r.status_code == 200 + names = [it["name"] for it in r.get_json()["items"]] + assert "and-both-tags" in names + assert "and-tactic-only" not in names + _ = a, b # silence unused vars from linter + + +def test_create_test_template_rejects_extra_fields(client, admin_token): + """`model_config = {"extra": "forbid"}` — unknown fields must 400.""" + r = client.post( + "/api/v1/test-templates", + headers=_bearer(admin_token), + json={"name": "extra-test", "rogue_field": "smuggled"}, + ) + assert r.status_code == 400 + + +def test_update_test_template_explicit_empty_mitre_clears(client, admin_token): + """`PUT { mitre_tags: [] }` is an explicit clear, not a no-op.""" + body = _make_test( + client, + admin_token, + name="clear-tags", + mitre_tags=[{"kind": "technique", "external_id": "T1059"}], + ) + assert len(body["mitre_tags"]) == 1 + r = client.put( + f"/api/v1/test-templates/{body['id']}", + headers=_bearer(admin_token), + json={"mitre_tags": []}, + ) + assert r.status_code == 200 + assert r.get_json()["mitre_tags"] == [] + + +def test_tag_item_length_capped_at_64(client, admin_token): + """Individual `tags` items must be ≤ 64 chars at the wire layer.""" + long_tag = "x" * 65 + r = client.post( + "/api/v1/test-templates", + headers=_bearer(admin_token), + json={"name": "long-tag", "tags": [long_tag]}, + ) + assert r.status_code == 400 diff --git a/e2e/tests/m5-templates.spec.ts b/e2e/tests/m5-templates.spec.ts new file mode 100644 index 0000000..8a58963 --- /dev/null +++ b/e2e/tests/m5-templates.spec.ts @@ -0,0 +1,253 @@ +import { expect, test, type APIRequestContext, type Page } from '@playwright/test'; + +/** + * M5 — Test + Scenario template catalogue. + * + * Verifies CRUD on /test-templates and /scenario-templates plus the admin SPA + * pages. We do NOT seed the full MITRE bundle here — M4 already covers that + * suite. This spec only needs ONE technique resolvable from a STIX-like + * shape (we ride on the same `/diag/reset` then re-seed MITRE so tag refs + * resolve). + */ + +const ADMIN_EMAIL = `admin-${crypto.randomUUID().slice(0, 8)}@metamorph.local`; +const ADMIN_PASSWORD = 'AdminPass1234!'; + +async function resetAndMintToken(request: APIRequestContext): Promise { + const r = await request.post('/api/v1/diag/reset'); + expect(r.status()).toBe(200); + return (await r.json()).install_token as string; +} + +async function loginAndGetAccess( + request: APIRequestContext, + email: string, + password: string, +): Promise { + const r = await request.post('/api/v1/auth/login', { data: { email, password } }); + expect(r.status()).toBe(200); + return (await r.json()).access_token as string; +} + +async function loginViaSpa(page: Page, email: string, password: string) { + await page.goto('/login'); + await page.getByLabel(/email/i).fill(email); + await page.getByLabel(/password/i).fill(password); + await page.getByRole('button', { name: /sign in/i }).click(); + await expect(page.getByTestId('me-email')).toHaveText(email); +} + +test.describe.configure({ mode: 'serial' }); + +test.describe('M5 — Template catalogue', () => { + test.beforeAll(async ({ request }) => { + const installToken = await resetAndMintToken(request); + const setup = await request.post('/api/v1/setup', { + data: { install_token: installToken, email: ADMIN_EMAIL, password: ADMIN_PASSWORD }, + }); + expect(setup.status()).toBe(201); + // MITRE re-sync — picker + tag refs rely on the canonical bundle. + const access = await loginAndGetAccess(request, ADMIN_EMAIL, ADMIN_PASSWORD); + const sync = await request.post('/api/v1/mitre/sync', { + headers: { Authorization: `Bearer ${access}` }, + }); + expect(sync.status()).toBe(200); + }); + + test.afterAll(async ({ request }) => { + // Restore the stable admin (cf. memory feedback_metamorph_test_admin): + // any wipe should leave admin@metamorph.local / AdminPass1234! usable. + const installToken = await resetAndMintToken(request); + await request.post('/api/v1/setup', { + data: { + install_token: installToken, + email: 'admin@metamorph.local', + password: 'AdminPass1234!', + }, + }); + // Re-seed MITRE so subsequent manual sessions don't see an empty matrix. + const access = await loginAndGetAccess(request, 'admin@metamorph.local', 'AdminPass1234!'); + await request.post('/api/v1/mitre/sync', { + headers: { Authorization: `Bearer ${access}` }, + }); + }); + + // === API smoke ============================================================ + + test('CRUD test-templates via API', async ({ request }) => { + const access = await loginAndGetAccess(request, ADMIN_EMAIL, ADMIN_PASSWORD); + const auth = { Authorization: `Bearer ${access}` }; + + // Create + const r1 = await request.post('/api/v1/test-templates', { + headers: auth, + data: { + name: 'phish-link', + description: 'send a phishing email with tracked link', + objective: 'land a click', + procedure_md: '1. craft mail\n2. send\n3. await click', + opsec_level: 'low', + tags: ['phish', 'initial-access'], + expected_iocs: ['phish@example.com'], + mitre_tags: [ + { kind: 'tactic', external_id: 'TA0001' }, + { kind: 'technique', external_id: 'T1566' }, + ], + }, + }); + expect(r1.status(), await r1.text()).toBe(201); + const created = await r1.json(); + expect(created.name).toBe('phish-link'); + expect(created.mitre_tags.length).toBe(2); + expect(created.tags).toContain('phish'); + + // Update — partial: change opsec only + const r2 = await request.put(`/api/v1/test-templates/${created.id}`, { + headers: auth, + data: { opsec_level: 'high' }, + }); + expect(r2.status()).toBe(200); + const updated = await r2.json(); + expect(updated.opsec_level).toBe('high'); + expect(updated.name).toBe('phish-link'); // untouched + + // List + filter by tactic + const r3 = await request.get('/api/v1/test-templates?tactic=TA0001', { + headers: auth, + }); + expect(r3.status()).toBe(200); + const list = await r3.json(); + expect(list.items.map((it: { name: string }) => it.name)).toContain('phish-link'); + + // Reject unknown MITRE + const r4 = await request.post('/api/v1/test-templates', { + headers: auth, + data: { + name: 'bad', + mitre_tags: [{ kind: 'technique', external_id: 'T9999' }], + }, + }); + expect(r4.status()).toBe(400); + expect((await r4.json()).error).toBe('unknown_mitre_tag'); + + // Soft-delete + const r5 = await request.delete(`/api/v1/test-templates/${created.id}`, { + headers: auth, + }); + expect(r5.status()).toBe(200); + const r6 = await request.get('/api/v1/test-templates', { headers: auth }); + expect( + (await r6.json()).items.map((it: { name: string }) => it.name), + ).not.toContain('phish-link'); + }); + + test('Scenario template: create + reorder + soft-delete', async ({ request }) => { + const access = await loginAndGetAccess(request, ADMIN_EMAIL, ADMIN_PASSWORD); + const auth = { Authorization: `Bearer ${access}` }; + + async function mkTest(name: string): Promise { + const r = await request.post('/api/v1/test-templates', { + headers: auth, + data: { name }, + }); + expect(r.status()).toBe(201); + return (await r.json()).id as string; + } + + const a = await mkTest('scn-step-a'); + const b = await mkTest('scn-step-b'); + const c = await mkTest('scn-step-c'); + + // Create with [a, b, c] + const r1 = await request.post('/api/v1/scenario-templates', { + headers: auth, + data: { name: 'ordered-scenario', test_template_ids: [a, b, c] }, + }); + expect(r1.status()).toBe(201); + const sc = await r1.json(); + expect(sc.tests.map((t: { test_template_name: string }) => t.test_template_name)).toEqual([ + 'scn-step-a', + 'scn-step-b', + 'scn-step-c', + ]); + + // Reorder → [c, a, b] + const r2 = await request.put(`/api/v1/scenario-templates/${sc.id}/tests`, { + headers: auth, + data: { test_template_ids: [c, a, b] }, + }); + expect(r2.status()).toBe(200); + const after = await r2.json(); + expect(after.tests.map((t: { test_template_name: string }) => t.test_template_name)).toEqual([ + 'scn-step-c', + 'scn-step-a', + 'scn-step-b', + ]); + + // Soft-delete the scenario. + const r3 = await request.delete(`/api/v1/scenario-templates/${sc.id}`, { headers: auth }); + expect(r3.status()).toBe(200); + const list = await (await request.get('/api/v1/scenario-templates', { headers: auth })).json(); + expect(list.items.map((it: { name: string }) => it.name)).not.toContain('ordered-scenario'); + }); + + // === SPA smoke ============================================================ + + test('SPA — admin sees the test catalogue and can filter', async ({ page, request }) => { + // Seed two tests up front via the API — exercise the SPA list + filter + // pipeline without fighting the heavy create-modal (covered by API tests). + const access = await loginAndGetAccess(request, ADMIN_EMAIL, ADMIN_PASSWORD); + const auth = { Authorization: `Bearer ${access}` }; + await request.post('/api/v1/test-templates', { + headers: auth, + data: { name: 'spa-list-fast', opsec_level: 'low', tags: ['fast'] }, + }); + await request.post('/api/v1/test-templates', { + headers: auth, + data: { name: 'spa-list-slow', opsec_level: 'high' }, + }); + + await loginViaSpa(page, ADMIN_EMAIL, ADMIN_PASSWORD); + await page.goto('/admin/tests'); + + await expect(page.getByText('spa-list-fast')).toBeVisible(); + await expect(page.getByText('spa-list-slow')).toBeVisible(); + + await page.getByTestId('filter-opsec').selectOption('high'); + await expect(page.getByText('spa-list-slow')).toBeVisible(); + await expect(page.getByText('spa-list-fast')).toBeHidden(); + }); + + test('SPA — scenario list shows ordered tests with their position', async ({ page, request }) => { + // Seed a 3-test scenario via the API; the SPA must render the order as + // saved. Pointer-event drag is flaky in CI, and the API-level reorder + // test already covers the persistence pipeline. + const access = await loginAndGetAccess(request, ADMIN_EMAIL, ADMIN_PASSWORD); + const auth = { Authorization: `Bearer ${access}` }; + const ids: string[] = []; + for (const name of ['drag-1', 'drag-2', 'drag-3']) { + const r = await request.post('/api/v1/test-templates', { + headers: auth, + data: { name }, + }); + ids.push((await r.json()).id); + } + const scResp = await request.post('/api/v1/scenario-templates', { + headers: auth, + data: { + name: 'spa-rendered-scenario', + test_template_ids: [ids[2], ids[0], ids[1]], + }, + }); + const scId = (await scResp.json()).id; + + await loginViaSpa(page, ADMIN_EMAIL, ADMIN_PASSWORD); + await page.goto('/admin/scenarios'); + + const card = page.locator(`[data-testid="scenario-row-${scId}"]`); + await expect(card).toBeVisible(); + await expect(card.getByText('1. drag-3')).toBeVisible(); + await expect(card.getByText('2. drag-1')).toBeVisible(); + await expect(card.getByText('3. drag-2')).toBeVisible(); + }); +}); diff --git a/frontend/package.json b/frontend/package.json index 1f7e91f..2da9cbc 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -13,6 +13,9 @@ "format:check": "prettier --check \"src/**/*.{ts,tsx,css,json,html}\"" }, "dependencies": { + "@dnd-kit/core": "^6.1.0", + "@dnd-kit/sortable": "^8.0.0", + "@dnd-kit/utilities": "^3.2.2", "@fontsource/ibm-plex-sans": "^5.0.20", "@fontsource/jetbrains-mono": "^5.0.20", "@tanstack/react-query": "^5.51.0", diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index db48353..44f8731 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -6,6 +6,8 @@ import { RequireAdmin } from '@/components/RequireAdmin'; import { RequireAuth } from '@/components/RequireAuth'; import { AdminGroupsPage } from '@/pages/AdminGroupsPage'; import { AdminInvitationsPage } from '@/pages/AdminInvitationsPage'; +import { AdminScenariosPage } from '@/pages/AdminScenariosPage'; +import { AdminTestsPage } from '@/pages/AdminTestsPage'; import { AdminUsersPage } from '@/pages/AdminUsersPage'; import { HomePage } from '@/pages/HomePage'; import { MitrePage } from '@/pages/MitrePage'; @@ -82,6 +84,22 @@ function App() { } /> + + + + } + /> + + + + } + /> } /> diff --git a/frontend/src/components/Layout.tsx b/frontend/src/components/Layout.tsx index 0eab999..a9b269b 100644 --- a/frontend/src/components/Layout.tsx +++ b/frontend/src/components/Layout.tsx @@ -42,6 +42,8 @@ export function Layout() { {navItem('/admin/users', 'Users')} {navItem('/admin/groups', 'Groups')} {navItem('/admin/invitations', 'Invitations')} + {navItem('/admin/tests', 'Tests')} + {navItem('/admin/scenarios', 'Scenarios')} )} @@ -69,7 +71,7 @@ export function Layout() {
- metamorph · M0 bootstrap · M1 db schema · M2 auth · M3 rbac · M4 mitre · design system from tasks/design.md + metamorph · M0 bootstrap · M1 db schema · M2 auth · M3 rbac · M4 mitre · M5 templates · design system from tasks/design.md
diff --git a/frontend/src/components/MarkdownField.tsx b/frontend/src/components/MarkdownField.tsx new file mode 100644 index 0000000..635793a --- /dev/null +++ b/frontend/src/components/MarkdownField.tsx @@ -0,0 +1,45 @@ +import { useId, type TextareaHTMLAttributes } from 'react'; + +import { cn } from '@/lib/cn'; + +interface MarkdownFieldProps extends Omit, 'value' | 'onChange'> { + label: string; + value: string; + onChange: (next: string) => void; + rows?: number; + hint?: string; +} + +/** + * Markdown-content textarea. We deliberately keep it textarea-only (no fancy + * WYSIWYG editor) — markdown lives well in plain text and the saved blob is + * rendered to HTML at display time (M6/M7 mission pages). The label exposes + * "markdown" so the user knows the field accepts MD syntax. + */ +export function MarkdownField({ label, value, onChange, rows = 6, hint, id, className, ...rest }: MarkdownFieldProps) { + const fallbackId = useId(); + const inputId = id ?? fallbackId; + return ( +
+ +