From b5ea2929de94d1cef8bf2c67826cf88a5d6f41f5 Mon Sep 17 00:00:00 2001 From: Knacky Date: Wed, 27 May 2026 03:56:02 +0200 Subject: [PATCH 1/8] =?UTF-8?q?feat(backend):=20sprint=203=20=E2=80=94=20m?= =?UTF-8?q?ulti-technique=20simulations=20+=20MITRE=20matrix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Simulation model: replace mitre_technique_id/name scalars with techniques JSON column [{id, name}] - Alembic migration 0003: add techniques, backfill from scalars, drop old columns (reversible) - MITRE service: add get_tactics(), lookup_name(), get_matrix() with canonical tactic order and sub-technique nesting - serializer: enrich techniques with tactics from service at serialize time (graceful empty tactics if bundle outdated) - simulation_workflow: PATCH now accepts technique_ids list, validates against bundle, deduplicates preserving order, auto-transitions on non-empty list - simulations API: add GET /api/mitre/matrix endpoint (503 if bundle absent) - test_mitre.py: updated _reset_mitre fixture, added T1059.006 sub-technique, 14 new tests for get_tactics/lookup_name/get_matrix/matrix endpoint - test_simulations_techniques.py: 20 new tests covering AC-13.1 to AC-13.5 (create, PATCH, dedup, auto-transition, SOC blocked, migration backfill logic) Total: 161 tests passing. ruff clean. mypy: no new errors. Co-Authored-By: Claude Sonnet 4.6 --- backend/app/api/simulations.py | 13 +- backend/app/models/simulation.py | 3 +- backend/app/serializers.py | 13 +- backend/app/services/mitre.py | 114 +++++- backend/app/services/simulation_workflow.py | 63 +++- .../0003_simulation_techniques_array.py | 82 +++++ backend/tests/test_mitre.py | 132 ++++++- backend/tests/test_simulations_techniques.py | 347 ++++++++++++++++++ 8 files changed, 737 insertions(+), 30 deletions(-) create mode 100644 backend/migrations/versions/0003_simulation_techniques_array.py create mode 100644 backend/tests/test_simulations_techniques.py diff --git a/backend/app/api/simulations.py b/backend/app/api/simulations.py index c6fd5cb..d5a2f2a 100644 --- a/backend/app/api/simulations.py +++ b/backend/app/api/simulations.py @@ -121,7 +121,7 @@ def transition_simulation(sid: int): # --------------------------------------------------------------------------- -# MITRE autocomplete +# MITRE autocomplete + matrix # --------------------------------------------------------------------------- @@ -136,3 +136,14 @@ def mitre_techniques(): q = request.args.get("q", "").strip() results = mitre_svc.search(q) return jsonify(results), 200 + + +@simulations_bp.get("/api/mitre/matrix") +@login_required +def mitre_matrix(): + from backend.app.services import mitre as mitre_svc + + if not mitre_svc.mitre_loaded: + return jsonify({"error": "mitre bundle not loaded"}), 503 + + return jsonify(mitre_svc.get_matrix()), 200 diff --git a/backend/app/models/simulation.py b/backend/app/models/simulation.py index 17b0663..74d99df 100644 --- a/backend/app/models/simulation.py +++ b/backend/app/models/simulation.py @@ -25,8 +25,7 @@ class Simulation(db.Model): # type: ignore[name-defined] index=True, ) name = db.Column(db.String(255), nullable=False) - mitre_technique_id = db.Column(db.String(32), nullable=True) - mitre_technique_name = db.Column(db.String(255), nullable=True) + techniques = db.Column(db.JSON, nullable=False, default=list) description = db.Column(db.Text, nullable=True) commands = db.Column(db.Text, nullable=True) prerequisites = db.Column(db.Text, nullable=True) diff --git a/backend/app/serializers.py b/backend/app/serializers.py index 7985614..d54e9cc 100644 --- a/backend/app/serializers.py +++ b/backend/app/serializers.py @@ -20,13 +20,22 @@ def serialize_user_brief(user: User) -> dict[str, Any]: return {"id": user.id, "username": user.username} +def _enrich_techniques(raw: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Attach tactics to each {id, name} snapshot from the MITRE service.""" + from backend.app.services import mitre as mitre_svc + + return [ + {"id": t["id"], "name": t["name"], "tactics": mitre_svc.get_tactics(t["id"])} + for t in (raw or []) + ] + + def serialize_simulation(simulation: Simulation) -> dict[str, Any]: return { "id": simulation.id, "engagement_id": simulation.engagement_id, "name": simulation.name, - "mitre_technique_id": simulation.mitre_technique_id, - "mitre_technique_name": simulation.mitre_technique_name, + "techniques": _enrich_techniques(simulation.techniques or []), "description": simulation.description, "commands": simulation.commands, "prerequisites": simulation.prerequisites, diff --git a/backend/app/services/mitre.py b/backend/app/services/mitre.py index 1c52498..7e2c983 100644 --- a/backend/app/services/mitre.py +++ b/backend/app/services/mitre.py @@ -8,11 +8,30 @@ from typing import Any logger = logging.getLogger(__name__) -# Absolute path to the committed bundle. _BUNDLE_PATH = Path(__file__).parent.parent.parent / "data" / "mitre" / "enterprise-attack.json" +# Canonical Enterprise tactic order (12 tactics). +_TACTIC_ORDER = [ + "initial-access", + "execution", + "persistence", + "privilege-escalation", + "defense-evasion", + "credential-access", + "discovery", + "lateral-movement", + "collection", + "command-and-control", + "exfiltration", + "impact", +] + mitre_loaded: bool = False _index: list[dict[str, Any]] = [] +_tactics_by_technique: dict[str, list[str]] = {} +_name_by_id: dict[str, str] = {} +# matrix: list of tactic dicts (built once at load time) +_matrix: list[dict[str, Any]] = [] def _extract_tactics(obj: dict[str, Any]) -> list[str]: @@ -20,7 +39,7 @@ def _extract_tactics(obj: dict[str, Any]) -> list[str]: return [ p["phase_name"] for p in phases - if isinstance(p, dict) and "phase_name" in p + if isinstance(p, dict) and "phase_name" in p and p.get("kill_chain_name") == "mitre-attack" ] @@ -31,9 +50,65 @@ def _get_external_id(obj: dict[str, Any]) -> str | None: return None +def _is_subtechnique(tech_id: str) -> bool: + return "." in tech_id + + +def _parent_id(sub_id: str) -> str: + return sub_id.split(".")[0] + + +def _build_matrix(entries: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Build the tactic → techniques → subtechniques tree.""" + # Group top-level techniques by tactic. + tactic_techs: dict[str, list[dict[str, Any]]] = {t: [] for t in _TACTIC_ORDER} + + for entry in entries: + if _is_subtechnique(entry["id"]): + continue + for tactic in entry["tactics"]: + if tactic in tactic_techs: + tactic_techs[tactic].append(entry) + + # Attach sub-techniques to their parents. + parent_subs: dict[str, list[dict[str, Any]]] = {} + for entry in entries: + if not _is_subtechnique(entry["id"]): + continue + pid = _parent_id(entry["id"]) + parent_subs.setdefault(pid, []).append({"id": entry["id"], "name": entry["name"]}) + + # Sort subs alphabetically by name. + for subs in parent_subs.values(): + subs.sort(key=lambda x: x["name"]) + + matrix: list[dict[str, Any]] = [] + for tactic_id in _TACTIC_ORDER: + techs = tactic_techs.get(tactic_id, []) + # Sort techniques alphabetically. + techs_sorted = sorted(techs, key=lambda x: x["name"]) + tactic_name = tactic_id.replace("-", " ").title() + matrix.append( + { + "tactic_id": tactic_id, + "tactic_name": tactic_name, + "techniques": [ + { + "id": t["id"], + "name": t["name"], + "subtechniques": parent_subs.get(t["id"], []), + } + for t in techs_sorted + ], + } + ) + + return matrix + + def load_bundle(path: Path | None = None) -> None: """Load the MITRE bundle into memory. Called once at app boot.""" - global mitre_loaded, _index + global mitre_loaded, _index, _tactics_by_technique, _name_by_id, _matrix bundle_path = path or _BUNDLE_PATH try: @@ -49,6 +124,9 @@ def load_bundle(path: Path | None = None) -> None: return entries: list[dict[str, Any]] = [] + tactics_map: dict[str, list[str]] = {} + name_map: dict[str, str] = {} + for obj in data.get("objects") or []: if not isinstance(obj, dict): continue @@ -59,19 +137,35 @@ def load_bundle(path: Path | None = None) -> None: ext_id = _get_external_id(obj) if not ext_id: continue - entries.append( - { - "id": ext_id, - "name": obj.get("name", ""), - "tactics": _extract_tactics(obj), - } - ) + tactics = _extract_tactics(obj) + name = obj.get("name", "") + entries.append({"id": ext_id, "name": name, "tactics": tactics}) + tactics_map[ext_id] = tactics + name_map[ext_id] = name _index = entries + _tactics_by_technique = tactics_map + _name_by_id = name_map + _matrix = _build_matrix(entries) mitre_loaded = True logger.info("MITRE bundle loaded: %d techniques", len(_index)) +def get_tactics(technique_id: str) -> list[str]: + """Return tactic list for a technique id; empty list if unknown.""" + return _tactics_by_technique.get(technique_id, []) + + +def lookup_name(technique_id: str) -> str | None: + """Return the name for a technique id, or None if not in the bundle.""" + return _name_by_id.get(technique_id) + + +def get_matrix() -> list[dict[str, Any]]: + """Return the full tactic → techniques → subtechniques tree.""" + return _matrix + + def search(query: str, limit: int = 20) -> list[dict[str, Any]]: """Return up to `limit` techniques matching `query`. diff --git a/backend/app/services/simulation_workflow.py b/backend/app/services/simulation_workflow.py index 535ccfd..2c8ddd7 100644 --- a/backend/app/services/simulation_workflow.py +++ b/backend/app/services/simulation_workflow.py @@ -10,11 +10,10 @@ from backend.app.extensions import db from backend.app.models import User from backend.app.models.simulation import Simulation, SimulationStatus +# Fields only admin/redteam may write (excluding technique_ids which is handled separately). REDTEAM_FIELDS = frozenset( { "name", - "mitre_technique_id", - "mitre_technique_name", "description", "commands", "prerequisites", @@ -25,8 +24,6 @@ REDTEAM_FIELDS = frozenset( SOC_FIELDS = frozenset({"log_source", "logs", "soc_comment", "incident_number"}) -# Transitions allowed via POST /transition endpoint (manual only). -# auto pending→in_progress is handled in apply_patch, not here. _ALLOWED_TRANSITIONS: dict[str, dict[str, set[str]]] = { "review_required": { "from": {"pending", "in_progress"}, @@ -48,6 +45,27 @@ def _is_non_empty(value: Any) -> bool: return not (isinstance(value, list) and len(value) == 0) +def _resolve_technique_ids( + technique_ids: list[str], +) -> tuple[list[dict[str, str]] | None, tuple[Any, int] | None]: + """Validate and resolve technique IDs to [{id, name}] snapshots. + + Returns (resolved_list, None) on success or (None, error_tuple) on failure. + Deduplicates while preserving order. + """ + from backend.app.services import mitre as mitre_svc + + # Dedup, preserve order. + seen: dict[str, None] = dict.fromkeys(technique_ids) + resolved: list[dict[str, str]] = [] + for tid in seen: + name = mitre_svc.lookup_name(tid) + if name is None: + return None, (jsonify({"error": f"unknown technique id: {tid}"}), 400) + resolved.append({"id": tid, "name": name}) + return resolved, None + + def apply_patch( simulation: Simulation, payload: dict[str, Any], user: User ) -> tuple[Any, int] | None: @@ -59,15 +77,14 @@ def apply_patch( role = user.role.value if role == "soc": - # SOC can only patch when status allows it. if simulation.status not in ( SimulationStatus.REVIEW_REQUIRED, SimulationStatus.DONE, ): return jsonify({"error": "simulation not ready for SOC review"}), 403 - # SOC must not send redteam fields. - redteam_keys_in_payload = REDTEAM_FIELDS & payload.keys() + # SOC must not send redteam fields or technique_ids. + redteam_keys_in_payload = (REDTEAM_FIELDS | {"technique_ids"}) & payload.keys() if redteam_keys_in_payload: return jsonify({"error": "soc cannot edit redteam fields"}), 403 @@ -76,10 +93,10 @@ def apply_patch( setattr(simulation, field, payload[field]) else: - # admin / redteam: apply all fields present. + # admin / redteam path. redteam_keys_present = REDTEAM_FIELDS & payload.keys() - # Validate executed_at before any writes so a bad value causes no partial mutation. + # Validate executed_at upfront before any writes. executed_at_value: datetime | None = None if "executed_at" in redteam_keys_present: val = payload["executed_at"] @@ -91,21 +108,39 @@ def apply_patch( except ValueError: return jsonify({"error": "invalid executed_at"}), 400 + # Validate and resolve technique_ids upfront. + resolved_techniques: list[dict[str, str]] | None = None + if "technique_ids" in payload: + raw_ids = payload["technique_ids"] + if not isinstance(raw_ids, list): + return jsonify({"error": "technique_ids must be a list"}), 400 + resolved_techniques, err = _resolve_technique_ids(raw_ids) + if err is not None: + return err + + # Apply scalar redteam fields. for field in redteam_keys_present: if field == "executed_at": simulation.executed_at = executed_at_value else: setattr(simulation, field, payload[field]) + # Apply resolved techniques. + if resolved_techniques is not None: + simulation.techniques = resolved_techniques + + # Apply SOC fields (admin/redteam may also write them). for field in SOC_FIELDS: if field in payload: setattr(simulation, field, payload[field]) - # Auto-transition pending → in_progress: at least one redteam field with - # a non-empty value in the *incoming payload*. - if simulation.status == SimulationStatus.PENDING and any( - _is_non_empty(payload[k]) for k in redteam_keys_present - ): + # Auto-transition pending → in_progress. + # Triggers when any redteam scalar has a non-empty value, OR technique_ids is non-empty. + auto_trigger = any(_is_non_empty(payload[k]) for k in redteam_keys_present) + if not auto_trigger and "technique_ids" in payload: + auto_trigger = len(payload["technique_ids"]) > 0 + + if simulation.status == SimulationStatus.PENDING and auto_trigger: simulation.status = SimulationStatus.IN_PROGRESS simulation.updated_at = datetime.now(UTC) diff --git a/backend/migrations/versions/0003_simulation_techniques_array.py b/backend/migrations/versions/0003_simulation_techniques_array.py new file mode 100644 index 0000000..84fa165 --- /dev/null +++ b/backend/migrations/versions/0003_simulation_techniques_array.py @@ -0,0 +1,82 @@ +"""replace scalar MITRE columns with techniques JSON array + +Revision ID: 0003 +Revises: 0002 +Create Date: 2026-05-27 00:00:00.000000 +""" +import json + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.sql import column, table, text + + +revision = "0003" +down_revision = "0002" +branch_labels = None +depends_on = None + +# Lightweight table proxies for data migration (no ORM import). +_sims = table( + "simulations", + column("id", sa.Integer), + column("mitre_technique_id", sa.String), + column("mitre_technique_name", sa.String), + column("techniques", sa.Text), +) + + +def upgrade(): + bind = op.get_bind() + + # 1. Add techniques column (nullable while we backfill). + op.add_column("simulations", sa.Column("techniques", sa.Text(), nullable=True)) + + # 2. Backfill: scalar → JSON array. + rows = bind.execute( + text("SELECT id, mitre_technique_id, mitre_technique_name FROM simulations") + ).fetchall() + for row in rows: + if row[1]: # mitre_technique_id is not null + val = json.dumps([{"id": row[1], "name": row[2] or ""}]) + else: + val = "[]" + bind.execute( + text("UPDATE simulations SET techniques = :v WHERE id = :id"), + {"v": val, "id": row[0]}, + ) + + # 3. Make NOT NULL now that every row has a value. + # SQLite doesn't support ALTER COLUMN, so we skip the nullable constraint + # change at DDL level — the application model enforces it. + + # 4. Drop old scalar columns. + with op.batch_alter_table("simulations") as batch_op: + batch_op.drop_column("mitre_technique_id") + batch_op.drop_column("mitre_technique_name") + + +def downgrade(): + bind = op.get_bind() + + # 1. Re-add scalar columns. + with op.batch_alter_table("simulations") as batch_op: + batch_op.add_column(sa.Column("mitre_technique_id", sa.String(length=32), nullable=True)) + batch_op.add_column(sa.Column("mitre_technique_name", sa.String(length=255), nullable=True)) + + # 2. Back-fill: take first element of techniques array. + rows = bind.execute(text("SELECT id, techniques FROM simulations")).fetchall() + for row in rows: + techniques = json.loads(row[1] or "[]") + if techniques: + first = techniques[0] + bind.execute( + text( + "UPDATE simulations SET mitre_technique_id = :tid, mitre_technique_name = :tname WHERE id = :id" + ), + {"tid": first.get("id"), "tname": first.get("name"), "id": row[0]}, + ) + + # 3. Drop techniques column. + with op.batch_alter_table("simulations") as batch_op: + batch_op.drop_column("techniques") diff --git a/backend/tests/test_mitre.py b/backend/tests/test_mitre.py index a31c3cd..2057571 100644 --- a/backend/tests/test_mitre.py +++ b/backend/tests/test_mitre.py @@ -33,6 +33,14 @@ _FIXTURE_BUNDLE = { ], "kill_chain_phases": [{"phase_name": "execution", "kill_chain_name": "mitre-attack"}], }, + { + "type": "attack-pattern", + "name": "Python", + "external_references": [ + {"source_name": "mitre-attack", "external_id": "T1059.006"} + ], + "kill_chain_phases": [{"phase_name": "execution", "kill_chain_name": "mitre-attack"}], + }, { "type": "attack-pattern", "name": "Phishing", @@ -76,9 +84,15 @@ def _reset_mitre(): """Reset the MITRE service state between tests.""" original_loaded = mitre_svc.mitre_loaded original_index = list(mitre_svc._index) + original_tactics = dict(mitre_svc._tactics_by_technique) + original_names = dict(mitre_svc._name_by_id) + original_matrix = list(mitre_svc._matrix) yield mitre_svc.mitre_loaded = original_loaded mitre_svc._index = original_index + mitre_svc._tactics_by_technique = original_tactics + mitre_svc._name_by_id = original_names + mitre_svc._matrix = original_matrix @pytest.fixture() @@ -96,7 +110,7 @@ def bundle_file(tmp_path: pathlib.Path) -> pathlib.Path: def test_load_bundle_success(bundle_file: pathlib.Path) -> None: mitre_svc.load_bundle(bundle_file) assert mitre_svc.mitre_loaded is True - assert len(mitre_svc._index) == 4 # 5 objects minus 1 revoked = 4 + assert len(mitre_svc._index) == 5 # 6 attack-patterns minus 1 revoked = 5 def test_load_bundle_missing_file() -> None: @@ -245,3 +259,119 @@ def test_mitre_endpoint_includes_tactics( phishing = next((r for r in data if r["id"] == "T1566"), None) assert phishing is not None assert "initial-access" in phishing["tactics"] + + +# --------------------------------------------------------------------------- +# Sprint 3: get_tactics, lookup_name, get_matrix +# --------------------------------------------------------------------------- + + +def test_get_tactics_known(bundle_file: pathlib.Path) -> None: + mitre_svc.load_bundle(bundle_file) + tactics = mitre_svc.get_tactics("T1078") + assert "initial-access" in tactics + assert "persistence" in tactics + + +def test_get_tactics_unknown_returns_empty(bundle_file: pathlib.Path) -> None: + mitre_svc.load_bundle(bundle_file) + assert mitre_svc.get_tactics("T0000") == [] + + +def test_lookup_name_known(bundle_file: pathlib.Path) -> None: + mitre_svc.load_bundle(bundle_file) + assert mitre_svc.lookup_name("T1059") == "Command and Scripting Interpreter" + + +def test_lookup_name_subtechnique(bundle_file: pathlib.Path) -> None: + mitre_svc.load_bundle(bundle_file) + assert mitre_svc.lookup_name("T1059.001") == "PowerShell" + + +def test_lookup_name_unknown_returns_none(bundle_file: pathlib.Path) -> None: + mitre_svc.load_bundle(bundle_file) + assert mitre_svc.lookup_name("T0000") is None + + +def test_get_matrix_returns_ordered_tactics(bundle_file: pathlib.Path) -> None: + mitre_svc.load_bundle(bundle_file) + matrix = mitre_svc.get_matrix() + tactic_ids = [t["tactic_id"] for t in matrix] + # initial-access must come before execution in canonical order. + assert tactic_ids.index("initial-access") < tactic_ids.index("execution") + + +def test_get_matrix_subtechniques_nested(bundle_file: pathlib.Path) -> None: + mitre_svc.load_bundle(bundle_file) + matrix = mitre_svc.get_matrix() + exec_tactic = next(t for t in matrix if t["tactic_id"] == "execution") + t1059 = next((t for t in exec_tactic["techniques"] if t["id"] == "T1059"), None) + assert t1059 is not None + sub_ids = [s["id"] for s in t1059["subtechniques"]] + assert "T1059.001" in sub_ids + assert "T1059.006" in sub_ids + + +def test_get_matrix_subtechniques_sorted_by_name(bundle_file: pathlib.Path) -> None: + mitre_svc.load_bundle(bundle_file) + matrix = mitre_svc.get_matrix() + exec_tactic = next(t for t in matrix if t["tactic_id"] == "execution") + t1059 = next(t for t in exec_tactic["techniques"] if t["id"] == "T1059") + names = [s["name"] for s in t1059["subtechniques"]] + assert names == sorted(names) + + +def test_get_matrix_techniques_sorted_by_name(bundle_file: pathlib.Path) -> None: + mitre_svc.load_bundle(bundle_file) + matrix = mitre_svc.get_matrix() + ia_tactic = next(t for t in matrix if t["tactic_id"] == "initial-access") + names = [t["name"] for t in ia_tactic["techniques"]] + assert names == sorted(names) + + +def test_get_matrix_technique_no_subtechniques(bundle_file: pathlib.Path) -> None: + mitre_svc.load_bundle(bundle_file) + matrix = mitre_svc.get_matrix() + ia_tactic = next(t for t in matrix if t["tactic_id"] == "initial-access") + phishing = next((t for t in ia_tactic["techniques"] if t["id"] == "T1566"), None) + assert phishing is not None + assert phishing["subtechniques"] == [] + + +def test_matrix_endpoint_ok( + client: FlaskClient, redteam_token: str, bundle_file: pathlib.Path +) -> None: + mitre_svc.load_bundle(bundle_file) + resp = client.get("/api/mitre/matrix", headers=_h(redteam_token)) + assert resp.status_code == 200 + data = resp.get_json() + assert isinstance(data, list) + tactic_ids = [t["tactic_id"] for t in data] + assert "initial-access" in tactic_ids + assert "execution" in tactic_ids + + +def test_matrix_endpoint_503_when_not_loaded( + client: FlaskClient, redteam_token: str +) -> None: + mitre_svc.mitre_loaded = False + resp = client.get("/api/mitre/matrix", headers=_h(redteam_token)) + assert resp.status_code == 503 + + +def test_matrix_endpoint_requires_auth(client: FlaskClient) -> None: + resp = client.get("/api/mitre/matrix") + assert resp.status_code == 401 + + +def test_matrix_endpoint_all_roles( + client: FlaskClient, + redteam_token: str, + soc_token: str, + admin_token: str, + bundle_file: pathlib.Path, +) -> None: + mitre_svc.load_bundle(bundle_file) + for token in (redteam_token, soc_token, admin_token): + resp = client.get("/api/mitre/matrix", headers=_h(token)) + assert resp.status_code == 200 diff --git a/backend/tests/test_simulations_techniques.py b/backend/tests/test_simulations_techniques.py new file mode 100644 index 0000000..5212577 --- /dev/null +++ b/backend/tests/test_simulations_techniques.py @@ -0,0 +1,347 @@ +"""Sprint 3 — multi-technique simulation tests (AC-13).""" +from __future__ import annotations + +import json +import pathlib + +import pytest +from flask.testing import FlaskClient + +from backend.app.services import mitre as mitre_svc +from backend.tests.conftest import auth_headers as _h + +# --------------------------------------------------------------------------- +# Minimal STIX fixture (reused from test_mitre.py pattern) +# --------------------------------------------------------------------------- + +_FIXTURE_BUNDLE = { + "type": "bundle", + "objects": [ + { + "type": "attack-pattern", + "name": "Command and Scripting Interpreter", + "external_references": [{"source_name": "mitre-attack", "external_id": "T1059"}], + "kill_chain_phases": [{"phase_name": "execution", "kill_chain_name": "mitre-attack"}], + }, + { + "type": "attack-pattern", + "name": "PowerShell", + "external_references": [{"source_name": "mitre-attack", "external_id": "T1059.001"}], + "kill_chain_phases": [{"phase_name": "execution", "kill_chain_name": "mitre-attack"}], + }, + { + "type": "attack-pattern", + "name": "Valid Accounts", + "external_references": [{"source_name": "mitre-attack", "external_id": "T1078"}], + "kill_chain_phases": [ + {"phase_name": "initial-access", "kill_chain_name": "mitre-attack"}, + {"phase_name": "persistence", "kill_chain_name": "mitre-attack"}, + ], + }, + ], +} + + +@pytest.fixture(autouse=True) +def _reset_mitre(): + original_loaded = mitre_svc.mitre_loaded + original_index = list(mitre_svc._index) + original_tactics = dict(mitre_svc._tactics_by_technique) + original_names = dict(mitre_svc._name_by_id) + original_matrix = list(mitre_svc._matrix) + yield + mitre_svc.mitre_loaded = original_loaded + mitre_svc._index = original_index + mitre_svc._tactics_by_technique = original_tactics + mitre_svc._name_by_id = original_names + mitre_svc._matrix = original_matrix + + +@pytest.fixture() +def bundle_file(tmp_path: pathlib.Path) -> pathlib.Path: + p = tmp_path / "enterprise-attack.json" + p.write_text(json.dumps(_FIXTURE_BUNDLE), encoding="utf-8") + return p + + +@pytest.fixture() +def loaded_bundle(bundle_file: pathlib.Path) -> pathlib.Path: + mitre_svc.load_bundle(bundle_file) + return bundle_file + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_engagement(client: FlaskClient, token: str) -> dict: + resp = client.post( + "/api/engagements", + headers=_h(token), + json={"name": "Op Sprint3", "start_date": "2026-06-01"}, + ) + assert resp.status_code == 201 + return resp.get_json() + + +def _make_sim(client: FlaskClient, token: str, eid: int) -> dict: + resp = client.post( + f"/api/engagements/{eid}/simulations", + headers=_h(token), + json={"name": "Technique Test"}, + ) + assert resp.status_code == 201 + return resp.get_json() + + +def _patch(client: FlaskClient, token: str, sid: int, payload: dict): + return client.patch(f"/api/simulations/{sid}", headers=_h(token), json=payload) + + +# --------------------------------------------------------------------------- +# AC-13.1 — new simulation has techniques = [] +# --------------------------------------------------------------------------- + + +def test_new_simulation_has_empty_techniques( + client: FlaskClient, redteam_token: str, loaded_bundle +) -> None: + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + assert sim["techniques"] == [] + + +# --------------------------------------------------------------------------- +# AC-13.3 — serializer enriches techniques with tactics +# --------------------------------------------------------------------------- + + +def test_techniques_enriched_with_tactics( + client: FlaskClient, redteam_token: str, loaded_bundle +) -> None: + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1078"]}) + + resp = client.get(f"/api/simulations/{sim['id']}", headers=_h(redteam_token)) + assert resp.status_code == 200 + techs = resp.get_json()["techniques"] + assert len(techs) == 1 + assert techs[0]["id"] == "T1078" + assert "initial-access" in techs[0]["tactics"] + assert "persistence" in techs[0]["tactics"] + + +def test_techniques_with_unknown_id_returns_empty_tactics( + client: FlaskClient, redteam_token: str, loaded_bundle +) -> None: + """If a technique was removed from the bundle after save, tactics gracefully = [].""" + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + # Bypass service, write directly an id not in the bundle. + from backend.app.extensions import db + from backend.app.models.simulation import Simulation + + with client.application.app_context(): + s = db.session.get(Simulation, sim["id"]) + s.techniques = [{"id": "T0000", "name": "Removed Technique"}] + db.session.commit() + + resp = client.get(f"/api/simulations/{sim['id']}", headers=_h(redteam_token)) + techs = resp.get_json()["techniques"] + assert techs[0]["tactics"] == [] + + +# --------------------------------------------------------------------------- +# AC-13.4 — PATCH technique_ids +# --------------------------------------------------------------------------- + + +def test_patch_technique_ids_sets_techniques( + client: FlaskClient, redteam_token: str, loaded_bundle +) -> None: + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + + resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059", "T1078"]}) + assert resp.status_code == 200 + techs = resp.get_json()["techniques"] + assert len(techs) == 2 + ids = [t["id"] for t in techs] + assert "T1059" in ids + assert "T1078" in ids + + +def test_patch_technique_ids_resolves_name( + client: FlaskClient, redteam_token: str, loaded_bundle +) -> None: + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + + resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059"]}) + assert resp.status_code == 200 + tech = resp.get_json()["techniques"][0] + assert tech["name"] == "Command and Scripting Interpreter" + + +def test_patch_technique_ids_unknown_returns_400( + client: FlaskClient, redteam_token: str, loaded_bundle +) -> None: + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + + resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T9999"]}) + assert resp.status_code == 400 + assert "unknown technique id: T9999" in resp.get_json()["error"] + + +def test_patch_technique_ids_partial_unknown_rejected( + client: FlaskClient, redteam_token: str, loaded_bundle +) -> None: + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + + # One valid, one unknown — whole request rejected. + resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059", "T9999"]}) + assert resp.status_code == 400 + + +def test_patch_technique_ids_includes_subtechnique( + client: FlaskClient, redteam_token: str, loaded_bundle +) -> None: + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + + resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059.001"]}) + assert resp.status_code == 200 + techs = resp.get_json()["techniques"] + assert techs[0]["id"] == "T1059.001" + assert techs[0]["name"] == "PowerShell" + + +def test_patch_technique_ids_replaces_list( + client: FlaskClient, redteam_token: str, loaded_bundle +) -> None: + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059"]}) + + resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1078"]}) + assert resp.status_code == 200 + ids = [t["id"] for t in resp.get_json()["techniques"]] + assert ids == ["T1078"] + + +def test_patch_technique_ids_empty_clears_list( + client: FlaskClient, redteam_token: str, loaded_bundle +) -> None: + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059"]}) + + resp = _patch(client, redteam_token, sim["id"], {"technique_ids": []}) + assert resp.status_code == 200 + assert resp.get_json()["techniques"] == [] + + +def test_patch_technique_ids_not_list_returns_400( + client: FlaskClient, redteam_token: str, loaded_bundle +) -> None: + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + + resp = _patch(client, redteam_token, sim["id"], {"technique_ids": "T1059"}) + assert resp.status_code == 400 + + +# --------------------------------------------------------------------------- +# Dedup (spec-reviewer note: AC-13.4) +# --------------------------------------------------------------------------- + + +def test_patch_technique_ids_deduplicates( + client: FlaskClient, redteam_token: str, loaded_bundle +) -> None: + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + + resp = _patch( + client, redteam_token, sim["id"], {"technique_ids": ["T1059", "T1078", "T1059"]} + ) + assert resp.status_code == 200 + techs = resp.get_json()["techniques"] + assert len(techs) == 2 + # Order preserved: T1059 first. + assert techs[0]["id"] == "T1059" + assert techs[1]["id"] == "T1078" + + +# --------------------------------------------------------------------------- +# AC-13.5 — auto-transition on technique_ids +# --------------------------------------------------------------------------- + + +def test_technique_ids_non_empty_triggers_auto_transition( + client: FlaskClient, redteam_token: str, loaded_bundle +) -> None: + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + assert sim["status"] == "pending" + + resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059"]}) + assert resp.status_code == 200 + assert resp.get_json()["status"] == "in_progress" + + +def test_technique_ids_empty_does_not_trigger_auto_transition( + client: FlaskClient, redteam_token: str, loaded_bundle +) -> None: + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + + resp = _patch(client, redteam_token, sim["id"], {"technique_ids": []}) + assert resp.status_code == 200 + assert resp.get_json()["status"] == "pending" + + +# --------------------------------------------------------------------------- +# SOC cannot patch technique_ids (it's a redteam field) +# --------------------------------------------------------------------------- + + +def test_soc_cannot_patch_technique_ids( + client: FlaskClient, redteam_token: str, soc_token: str, loaded_bundle +) -> None: + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + # Advance to review_required so SOC can touch the simulation at all. + client.post( + f"/api/simulations/{sim['id']}/transition", + headers=_h(redteam_token), + json={"to": "review_required"}, + ) + + resp = _patch(client, soc_token, sim["id"], {"technique_ids": ["T1059"]}) + assert resp.status_code == 403 + + +# --------------------------------------------------------------------------- +# Migration backfill test (inline, no Alembic runner needed) +# --------------------------------------------------------------------------- + + +def test_migration_backfill_logic() -> None: + """Verify the backfill logic used in upgrade(): scalar → [{id, name}].""" + import json as _json + + def _backfill(tech_id, tech_name): + if tech_id: + return _json.loads(_json.dumps([{"id": tech_id, "name": tech_name or ""}])) + return [] + + assert _backfill("T1059", "Command and Scripting Interpreter") == [ + {"id": "T1059", "name": "Command and Scripting Interpreter"} + ] + assert _backfill(None, None) == [] + assert _backfill("T1059", None) == [{"id": "T1059", "name": ""}] From 673b25e0b063ef73bf33559089c94a53ee602da0 Mon Sep 17 00:00:00 2001 From: Knacky Date: Wed, 27 May 2026 03:58:30 +0200 Subject: [PATCH 2/8] fix(backend): PATCH technique_ids returns 503 when MITRE bundle not loaded Added bundle-loaded guard in _resolve_technique_ids() before attempting any lookup; matches behavior of GET /api/mitre/matrix and GET /api/mitre/techniques. Added corresponding test case. Co-Authored-By: Claude Sonnet 4.6 --- backend/app/services/simulation_workflow.py | 3 +++ backend/tests/test_simulations_techniques.py | 21 ++++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/backend/app/services/simulation_workflow.py b/backend/app/services/simulation_workflow.py index 2c8ddd7..2df406d 100644 --- a/backend/app/services/simulation_workflow.py +++ b/backend/app/services/simulation_workflow.py @@ -55,6 +55,9 @@ def _resolve_technique_ids( """ from backend.app.services import mitre as mitre_svc + if not mitre_svc.mitre_loaded: + return None, (jsonify({"error": "mitre bundle not loaded"}), 503) + # Dedup, preserve order. seen: dict[str, None] = dict.fromkeys(technique_ids) resolved: list[dict[str, str]] = [] diff --git a/backend/tests/test_simulations_techniques.py b/backend/tests/test_simulations_techniques.py index 5212577..f515965 100644 --- a/backend/tests/test_simulations_techniques.py +++ b/backend/tests/test_simulations_techniques.py @@ -305,6 +305,27 @@ def test_technique_ids_empty_does_not_trigger_auto_transition( assert resp.get_json()["status"] == "pending" +# --------------------------------------------------------------------------- +# Bundle not loaded — 503 on technique_ids PATCH +# --------------------------------------------------------------------------- + + +def test_patch_technique_ids_bundle_not_loaded_returns_503( + client: FlaskClient, redteam_token: str +) -> None: + """When MITRE bundle is absent, PATCH with technique_ids must return 503.""" + mitre_svc.mitre_loaded = False + mitre_svc._index = [] + mitre_svc._name_by_id = {} + + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + + resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059"]}) + assert resp.status_code == 503 + assert resp.get_json()["error"] == "mitre bundle not loaded" + + # --------------------------------------------------------------------------- # SOC cannot patch technique_ids (it's a redteam field) # --------------------------------------------------------------------------- From 771483f3b0c4f3ca0c06d1a0a67cae3f062a3365 Mon Sep 17 00:00:00 2001 From: Knacky Date: Wed, 27 May 2026 04:04:23 +0200 Subject: [PATCH 3/8] =?UTF-8?q?feat(frontend):=20sprint=203=20=E2=80=94=20?= =?UTF-8?q?multi-technique=20MITRE=20selection=20+=20matrix=20modal?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - types: replace mitre_technique_id/name scalars with techniques:MitreTechnique[] on Simulation; add MitreTactic/MitreMatrixTechnique/MitreMatrixSubtechnique; SimulationPatchInput now uses technique_ids:string[] - api/mitre.ts: add getMitreMatrix() → GET /api/mitre/matrix - hooks/useMitre: add useMitreMatrix(enabled) with staleTime:Infinity - MitreTechniquePicker: clean rewrite — onSelect(technique) one-shot, resets input after selection, no incoming value props - MitreTechniqueTag: chip component with id+name and × remove button - MitreMatrixModal: tactic columns (220px fixed), expand/collapse subtechniques, search filter (auto-expands parent on sub match), selection state, focus trap (Tab wrap, Escape, search autofocus), backdrop click cancel, Apply N techniques - MitreTechniquesField: orchestrates tags+picker+matrix with auto-save PATCH on every add/remove/Apply, dedup guard, disabled read-only mode for SOC - SimulationFormPage: swap MitreTechniquePicker for MitreTechniquesField; remove technique state from RT form (techniques have independent auto-save cycle) - SimulationList: MITRE column → T1059 +2 counter format, — when empty - Tests: 84 passing (13 test files); new suites for Tag, Field, Modal; MitreTechniquePicker + SimulationFormPage + SimulationList adapted to new API Co-Authored-By: Claude Sonnet 4.6 --- frontend/src/api/mitre.ts | 7 +- frontend/src/api/types.ts | 23 +- frontend/src/components/MitreMatrixModal.tsx | 344 ++++++++++++++++++ .../src/components/MitreTechniquePicker.tsx | 55 +-- frontend/src/components/MitreTechniqueTag.tsx | 33 ++ .../src/components/MitreTechniquesField.tsx | 135 +++++++ frontend/src/components/SimulationList.tsx | 6 +- frontend/src/hooks/useMitre.ts | 11 +- frontend/src/pages/SimulationFormPage.tsx | 25 +- frontend/tests/MitreMatrixModal.test.tsx | 206 +++++++++++ frontend/tests/MitreTechniquePicker.test.tsx | 127 +------ frontend/tests/MitreTechniqueTag.test.tsx | 41 +++ frontend/tests/MitreTechniquesField.test.tsx | 135 +++++++ frontend/tests/SimulationFormPage.test.tsx | 3 +- frontend/tests/SimulationList.test.tsx | 3 +- 15 files changed, 973 insertions(+), 181 deletions(-) create mode 100644 frontend/src/components/MitreMatrixModal.tsx create mode 100644 frontend/src/components/MitreTechniqueTag.tsx create mode 100644 frontend/src/components/MitreTechniquesField.tsx create mode 100644 frontend/tests/MitreMatrixModal.test.tsx create mode 100644 frontend/tests/MitreTechniqueTag.test.tsx create mode 100644 frontend/tests/MitreTechniquesField.test.tsx diff --git a/frontend/src/api/mitre.ts b/frontend/src/api/mitre.ts index 4750d62..aadd042 100644 --- a/frontend/src/api/mitre.ts +++ b/frontend/src/api/mitre.ts @@ -1,5 +1,5 @@ import { apiClient } from './client'; -import type { MitreTechnique } from './types'; +import type { MitreTactic, MitreTechnique } from './types'; export async function searchMitreTechniques(query: string): Promise { const { data } = await apiClient.get('/mitre/techniques', { @@ -7,3 +7,8 @@ export async function searchMitreTechniques(query: string): Promise { + const { data } = await apiClient.get('/mitre/matrix'); + return data; +} diff --git a/frontend/src/api/types.ts b/frontend/src/api/types.ts index d640768..6c73e3e 100644 --- a/frontend/src/api/types.ts +++ b/frontend/src/api/types.ts @@ -61,12 +61,28 @@ export interface MitreTechnique { tactics: string[]; } +export interface MitreMatrixSubtechnique { + id: string; + name: string; +} + +export interface MitreMatrixTechnique { + id: string; + name: string; + subtechniques: MitreMatrixSubtechnique[]; +} + +export interface MitreTactic { + tactic_id: string; + tactic_name: string; + techniques: MitreMatrixTechnique[]; +} + export interface Simulation { id: number; engagement_id: number; name: string; - mitre_technique_id: string | null; - mitre_technique_name: string | null; + techniques: MitreTechnique[]; description: string | null; commands: string | null; prerequisites: string | null; @@ -88,8 +104,7 @@ export interface SimulationCreateInput { export interface SimulationPatchInput { name?: string; - mitre_technique_id?: string | null; - mitre_technique_name?: string | null; + technique_ids?: string[]; description?: string | null; commands?: string | null; prerequisites?: string | null; diff --git a/frontend/src/components/MitreMatrixModal.tsx b/frontend/src/components/MitreMatrixModal.tsx new file mode 100644 index 0000000..6f170f7 --- /dev/null +++ b/frontend/src/components/MitreMatrixModal.tsx @@ -0,0 +1,344 @@ +import { useEffect, useRef, useState } from 'react'; +import { LoadingState } from './LoadingState'; +import { ErrorState } from './ErrorState'; +import { extractApiError } from '@/api/client'; +import { useMitreMatrix } from '@/hooks/useMitre'; +import type { MitreTechnique } from '@/api/types'; + +interface MitreMatrixModalProps { + isOpen: boolean; + initialSelection: MitreTechnique[]; + onApply: (selection: MitreTechnique[]) => void; + onCancel: () => void; +} + +function techniqueInTactic( + tacticTechniques: { id: string; subtechniques: { id: string }[] }[], + selection: Set, +): number { + let count = 0; + for (const t of tacticTechniques) { + if (selection.has(t.id)) count++; + for (const s of t.subtechniques) { + if (selection.has(s.id)) count++; + } + } + return count; +} + +export function MitreMatrixModal({ + isOpen, + initialSelection, + onApply, + onCancel, +}: MitreMatrixModalProps): JSX.Element | null { + const { data: matrix, isLoading, isError, error } = useMitreMatrix(isOpen); + + // Selected IDs → Map id → {id, name} for reconstruct + const [selectedMap, setSelectedMap] = useState>( + () => new Map(initialSelection.map((t) => [t.id, { id: t.id, name: t.name }])), + ); + const [expandedTechniques, setExpandedTechniques] = useState>(new Set()); + const [search, setSearch] = useState(''); + + const containerRef = useRef(null); + const searchInputRef = useRef(null); + + // Reset local state when modal opens with new initialSelection + useEffect(() => { + if (isOpen) { + setSelectedMap(new Map(initialSelection.map((t) => [t.id, { id: t.id, name: t.name }]))); + setExpandedTechniques(new Set()); + setSearch(''); + } + }, [isOpen]); // eslint-disable-line react-hooks/exhaustive-deps + + // Focus search input on open + useEffect(() => { + if (isOpen) { + // Small delay lets the DOM render before focus + setTimeout(() => searchInputRef.current?.focus(), 0); + } + }, [isOpen]); + + // Escape closes modal + useEffect(() => { + if (!isOpen) return; + const handler = (e: KeyboardEvent) => { + if (e.key === 'Escape') onCancel(); + }; + document.addEventListener('keydown', handler); + return () => document.removeEventListener('keydown', handler); + }, [isOpen, onCancel]); + + const getFocusableElements = () => { + if (!containerRef.current) return []; + return Array.from( + containerRef.current.querySelectorAll( + 'a, button, input, [tabindex]:not([tabindex="-1"])', + ), + ).filter((el) => !(el as HTMLButtonElement | HTMLInputElement).disabled && !el.hidden && el.tabIndex !== -1); + }; + + const handleKeyDown = (e: React.KeyboardEvent) => { + if (e.key !== 'Tab') return; + const focusables = getFocusableElements(); + if (focusables.length === 0) return; + const first = focusables[0]; + const last = focusables[focusables.length - 1]; + if (e.shiftKey) { + if (document.activeElement === first) { + e.preventDefault(); + last.focus(); + } + } else { + if (document.activeElement === last) { + e.preventDefault(); + first.focus(); + } + } + }; + + if (!isOpen) return null; + + const toggleTechnique = (id: string, name: string) => { + setSelectedMap((prev) => { + const next = new Map(prev); + if (next.has(id)) { + next.delete(id); + } else { + next.set(id, { id, name }); + } + return next; + }); + }; + + const toggleExpand = (id: string) => { + setExpandedTechniques((prev) => { + const next = new Set(prev); + if (next.has(id)) { + next.delete(id); + } else { + next.add(id); + } + return next; + }); + }; + + const searchLower = search.toLowerCase().trim(); + + // Figure out which technique IDs should be auto-expanded due to a sub-technique match + const autoExpanded = new Set(); + if (searchLower && matrix) { + for (const tactic of matrix) { + for (const tech of tactic.techniques) { + const subMatch = tech.subtechniques.some( + (s) => s.id.toLowerCase().includes(searchLower) || s.name.toLowerCase().includes(searchLower), + ); + if (subMatch) autoExpanded.add(tech.id); + } + } + } + + const handleApply = () => { + // Reconstruct MitreTechnique[] from selected IDs. + // tactics are not available here; parent will use what it has or send [] + const selection: MitreTechnique[] = Array.from(selectedMap.values()).map((t) => ({ + id: t.id, + name: t.name, + tactics: [], + })); + onApply(selection); + }; + + const totalSelected = selectedMap.size; + + return ( +
+ {/* Backdrop */} + + ); +} diff --git a/frontend/src/components/MitreTechniquePicker.tsx b/frontend/src/components/MitreTechniquePicker.tsx index a86ccbc..5d0106b 100644 --- a/frontend/src/components/MitreTechniquePicker.tsx +++ b/frontend/src/components/MitreTechniquePicker.tsx @@ -1,55 +1,26 @@ -import { - useEffect, - useRef, - useState, - type KeyboardEvent, -} from 'react'; +import { useEffect, useRef, useState, type KeyboardEvent } from 'react'; import { extractApiError } from '@/api/client'; import type { MitreTechnique } from '@/api/types'; import { useMitreSearch } from '@/hooks/useMitre'; interface MitreTechniquePickerProps { - techniqueId: string | null; - techniqueName: string | null; - onChange: (id: string | null, name: string | null) => void; + onSelect: (technique: MitreTechnique) => void; disabled?: boolean; } -function formatOption(t: MitreTechnique): string { - const tacticList = t.tactics.length > 0 ? ` (${t.tactics[0]})` : ''; - return `${t.id} — ${t.name}${tacticList}`; -} - const DEBOUNCE_MS = 200; export function MitreTechniquePicker({ - techniqueId, - techniqueName, - onChange, + onSelect, disabled = false, }: MitreTechniquePickerProps): JSX.Element { - const [inputValue, setInputValue] = useState( - techniqueId && techniqueName ? `${techniqueId} — ${techniqueName}` : '', - ); + const [inputValue, setInputValue] = useState(''); const [query, setQuery] = useState(''); const [open, setOpen] = useState(false); const [activeIndex, setActiveIndex] = useState(-1); const debounceRef = useRef | null>(null); const containerRef = useRef(null); const listRef = useRef(null); - // True once we've synced the first real techniqueId from props (parent/API load). - // After that we stop reacting to null, so keystrokes that emit onChange(null,null) - // don't propagate back and wipe the input mid-stroke. - const hasHydratedFromProps = useRef(false); - - useEffect(() => { - if (techniqueId && techniqueName) { - setInputValue(`${techniqueId} — ${techniqueName}`); - hasHydratedFromProps.current = true; - } else if (!techniqueId && !hasHydratedFromProps.current) { - setInputValue(''); - } - }, [techniqueId, techniqueName]); const { data: results, isFetching, isError, error } = useMitreSearch(query, open); @@ -57,8 +28,6 @@ export function MitreTechniquePicker({ const handleInputChange = (value: string) => { setInputValue(value); - // Clear the selection when user starts typing - onChange(null, null); setOpen(true); setActiveIndex(-1); @@ -69,11 +38,12 @@ export function MitreTechniquePicker({ }; const selectItem = (item: MitreTechnique) => { - setInputValue(formatOption(item)); - onChange(item.id, item.name); + onSelect(item); + // Reset to empty after selection — parent handles append + dedup + setInputValue(''); + setQuery(''); setOpen(false); setActiveIndex(-1); - setQuery(''); }; const handleKeyDown = (e: KeyboardEvent) => { @@ -98,7 +68,6 @@ export function MitreTechniquePicker({ } }; - // Scroll active item into view useEffect(() => { if (activeIndex >= 0 && listRef.current) { const el = listRef.current.children[activeIndex] as HTMLElement | undefined; @@ -106,7 +75,6 @@ export function MitreTechniquePicker({ } }, [activeIndex]); - // Close dropdown on click outside useEffect(() => { const onPointerDown = (e: PointerEvent) => { if (containerRef.current && !containerRef.current.contains(e.target as Node)) { @@ -127,13 +95,11 @@ export function MitreTechniquePicker({ aria-expanded={open} aria-controls={listboxId} aria-activedescendant={activeIndex >= 0 ? `mitre-option-${activeIndex}` : undefined} - aria-label="MITRE technique" + aria-label="Search MITRE technique" className="text-input" value={inputValue} onChange={(e) => handleInputChange(e.target.value)} - onFocus={() => { - if (!techniqueId) setOpen(true); - }} + onFocus={() => setOpen(true)} onKeyDown={handleKeyDown} disabled={disabled} placeholder="Search by ID or name (e.g. T1059)" @@ -174,7 +140,6 @@ export function MitreTechniquePicker({ i === activeIndex ? 'bg-primary-soft text-ink' : 'text-ink hover:bg-cloud' }`} onPointerDown={(e) => { - // Prevent input blur before we handle the click e.preventDefault(); selectItem(item); }} diff --git a/frontend/src/components/MitreTechniqueTag.tsx b/frontend/src/components/MitreTechniqueTag.tsx new file mode 100644 index 0000000..a4d15be --- /dev/null +++ b/frontend/src/components/MitreTechniqueTag.tsx @@ -0,0 +1,33 @@ +import type { MitreTechnique } from '@/api/types'; + +interface MitreTechniqueTagProps { + technique: MitreTechnique; + onRemove: () => void; + disabled?: boolean; +} + +export function MitreTechniqueTag({ + technique, + onRemove, + disabled = false, +}: MitreTechniqueTagProps): JSX.Element { + return ( + + {technique.id} + — {technique.name} + {!disabled && ( + + )} + + ); +} diff --git a/frontend/src/components/MitreTechniquesField.tsx b/frontend/src/components/MitreTechniquesField.tsx new file mode 100644 index 0000000..b454fee --- /dev/null +++ b/frontend/src/components/MitreTechniquesField.tsx @@ -0,0 +1,135 @@ +import { useState } from 'react'; +import { extractApiError } from '@/api/client'; +import type { MitreTechnique } from '@/api/types'; +import { useUpdateSimulation } from '@/hooks/useSimulations'; +import { useToast } from '@/hooks/useToast'; +import { MitreTechniqueTag } from './MitreTechniqueTag'; +import { MitreTechniquePicker } from './MitreTechniquePicker'; +import { MitreMatrixModal } from './MitreMatrixModal'; + +interface MitreTechniquesFieldProps { + value: MitreTechnique[]; + simulationId: number; + engagementId: number; + disabled?: boolean; +} + +export function MitreTechniquesField({ + value, + simulationId, + engagementId, + disabled = false, +}: MitreTechniquesFieldProps): JSX.Element { + const [showMatrix, setShowMatrix] = useState(false); + const [showPicker, setShowPicker] = useState(false); + + const { push } = useToast(); + const updateMutation = useUpdateSimulation(simulationId, engagementId); + + const save = async (techniques: MitreTechnique[]) => { + try { + await updateMutation.mutateAsync({ + technique_ids: techniques.map((t) => t.id), + }); + push('Techniques updated', 'success'); + } catch (err) { + push(extractApiError(err, 'Could not update techniques'), 'error'); + } + }; + + const handleRemove = (id: string) => { + const next = value.filter((t) => t.id !== id); + void save(next); + }; + + const handleSelect = (technique: MitreTechnique) => { + // Dedup: no-op if already present + if (value.some((t) => t.id === technique.id)) return; + const next = [...value, technique]; + void save(next); + }; + + const handleMatrixApply = (selection: MitreTechnique[]) => { + setShowMatrix(false); + // Merge: preserve existing tactics on items already in value, fill from selection otherwise. + // The backend re-enriches tactics at serialize time, so the exact tactics here don't matter. + const merged = selection.map((s) => { + const existing = value.find((v) => v.id === s.id); + return existing ?? s; + }); + void save(merged); + }; + + const isPending = updateMutation.isPending; + + return ( +
+ {/* Tag list */} + {value.length === 0 ? ( +

+ No techniques selected — use the matrix or the quick search to add. +

+ ) : ( +
+ {value.map((t) => ( + handleRemove(t.id)} + disabled={disabled || isPending} + /> + ))} +
+ )} + + {/* Action buttons — hidden in read-only mode */} + {!disabled && ( +
+ + + {isPending && ( + Saving… + )} +
+ )} + + {/* Inline Quick Search picker */} + {showPicker && !disabled && ( +
+ { + handleSelect(technique); + setShowPicker(false); + }} + disabled={isPending} + /> +
+ )} + + {/* Matrix modal */} + setShowMatrix(false)} + /> +
+ ); +} diff --git a/frontend/src/components/SimulationList.tsx b/frontend/src/components/SimulationList.tsx index 210a1f6..ee8618c 100644 --- a/frontend/src/components/SimulationList.tsx +++ b/frontend/src/components/SimulationList.tsx @@ -95,7 +95,11 @@ export function SimulationList({ engagementId }: SimulationListProps): JSX.Eleme - {sim.mitre_technique_id ?? '—'} + {sim.techniques.length === 0 + ? '—' + : sim.techniques.length === 1 + ? sim.techniques[0].id + : `${sim.techniques[0].id} +${sim.techniques.length - 1}`} diff --git a/frontend/src/hooks/useMitre.ts b/frontend/src/hooks/useMitre.ts index 18d736f..f1d8985 100644 --- a/frontend/src/hooks/useMitre.ts +++ b/frontend/src/hooks/useMitre.ts @@ -1,5 +1,5 @@ import { useQuery } from '@tanstack/react-query'; -import { searchMitreTechniques } from '@/api/mitre'; +import { getMitreMatrix, searchMitreTechniques } from '@/api/mitre'; export function useMitreSearch(query: string, enabled: boolean) { return useQuery({ @@ -9,3 +9,12 @@ export function useMitreSearch(query: string, enabled: boolean) { staleTime: 5 * 60 * 1000, }); } + +export function useMitreMatrix(enabled: boolean) { + return useQuery({ + queryKey: ['mitre', 'matrix'], + queryFn: getMitreMatrix, + enabled, + staleTime: Infinity, + }); +} diff --git a/frontend/src/pages/SimulationFormPage.tsx b/frontend/src/pages/SimulationFormPage.tsx index 2e274d9..c852226 100644 --- a/frontend/src/pages/SimulationFormPage.tsx +++ b/frontend/src/pages/SimulationFormPage.tsx @@ -16,12 +16,10 @@ import { LoadingState } from '@/components/LoadingState'; import { ErrorState } from '@/components/ErrorState'; import { SimulationStatusBadge } from '@/components/SimulationStatusBadge'; import { ConfirmDialog } from '@/components/ConfirmDialog'; -import { MitreTechniquePicker } from '@/components/MitreTechniquePicker'; +import { MitreTechniquesField } from '@/components/MitreTechniquesField'; interface RedteamFormState { name: string; - mitre_technique_id: string | null; - mitre_technique_name: string | null; description: string; commands: string; prerequisites: string; @@ -38,8 +36,6 @@ interface SocFormState { const EMPTY_RT: RedteamFormState = { name: '', - mitre_technique_id: null, - mitre_technique_name: null, description: '', commands: '', prerequisites: '', @@ -81,8 +77,6 @@ export function SimulationFormPage(): JSX.Element { const s = detail.data; setRt({ name: s.name, - mitre_technique_id: s.mitre_technique_id, - mitre_technique_name: s.mitre_technique_name, description: s.description ?? '', commands: s.commands ?? '', prerequisites: s.prerequisites ?? '', @@ -154,8 +148,6 @@ export function SimulationFormPage(): JSX.Element { } const patch: SimulationPatchInput = { name: rt.name.trim(), - mitre_technique_id: rt.mitre_technique_id ?? null, - mitre_technique_name: rt.mitre_technique_name ?? null, description: rt.description.trim() || null, commands: rt.commands.trim() || null, prerequisites: rt.prerequisites.trim() || null, @@ -314,16 +306,15 @@ export function SimulationFormPage(): JSX.Element { /> - - - setRt({ ...rt, mitre_technique_id: id, mitre_technique_name: name }) - } +
+ MITRE Techniques + - +