diff --git a/backend/app/api/simulations.py b/backend/app/api/simulations.py index c6fd5cb..d5a2f2a 100644 --- a/backend/app/api/simulations.py +++ b/backend/app/api/simulations.py @@ -121,7 +121,7 @@ def transition_simulation(sid: int): # --------------------------------------------------------------------------- -# MITRE autocomplete +# MITRE autocomplete + matrix # --------------------------------------------------------------------------- @@ -136,3 +136,14 @@ def mitre_techniques(): q = request.args.get("q", "").strip() results = mitre_svc.search(q) return jsonify(results), 200 + + +@simulations_bp.get("/api/mitre/matrix") +@login_required +def mitre_matrix(): + from backend.app.services import mitre as mitre_svc + + if not mitre_svc.mitre_loaded: + return jsonify({"error": "mitre bundle not loaded"}), 503 + + return jsonify(mitre_svc.get_matrix()), 200 diff --git a/backend/app/models/simulation.py b/backend/app/models/simulation.py index 17b0663..74d99df 100644 --- a/backend/app/models/simulation.py +++ b/backend/app/models/simulation.py @@ -25,8 +25,7 @@ class Simulation(db.Model): # type: ignore[name-defined] index=True, ) name = db.Column(db.String(255), nullable=False) - mitre_technique_id = db.Column(db.String(32), nullable=True) - mitre_technique_name = db.Column(db.String(255), nullable=True) + techniques = db.Column(db.JSON, nullable=False, default=list) description = db.Column(db.Text, nullable=True) commands = db.Column(db.Text, nullable=True) prerequisites = db.Column(db.Text, nullable=True) diff --git a/backend/app/serializers.py b/backend/app/serializers.py index 7985614..d54e9cc 100644 --- a/backend/app/serializers.py +++ b/backend/app/serializers.py @@ -20,13 +20,22 @@ def serialize_user_brief(user: User) -> dict[str, Any]: return {"id": user.id, "username": user.username} +def _enrich_techniques(raw: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Attach tactics to each {id, name} snapshot from the MITRE service.""" + from backend.app.services import mitre as mitre_svc + + return [ + {"id": t["id"], "name": t["name"], "tactics": mitre_svc.get_tactics(t["id"])} + for t in (raw or []) + ] + + def serialize_simulation(simulation: Simulation) -> dict[str, Any]: return { "id": simulation.id, "engagement_id": simulation.engagement_id, "name": simulation.name, - "mitre_technique_id": simulation.mitre_technique_id, - "mitre_technique_name": simulation.mitre_technique_name, + "techniques": _enrich_techniques(simulation.techniques or []), "description": simulation.description, "commands": simulation.commands, "prerequisites": simulation.prerequisites, diff --git a/backend/app/services/mitre.py b/backend/app/services/mitre.py index 1c52498..7e2c983 100644 --- a/backend/app/services/mitre.py +++ b/backend/app/services/mitre.py @@ -8,11 +8,30 @@ from typing import Any logger = logging.getLogger(__name__) -# Absolute path to the committed bundle. _BUNDLE_PATH = Path(__file__).parent.parent.parent / "data" / "mitre" / "enterprise-attack.json" +# Canonical Enterprise tactic order (12 tactics). +_TACTIC_ORDER = [ + "initial-access", + "execution", + "persistence", + "privilege-escalation", + "defense-evasion", + "credential-access", + "discovery", + "lateral-movement", + "collection", + "command-and-control", + "exfiltration", + "impact", +] + mitre_loaded: bool = False _index: list[dict[str, Any]] = [] +_tactics_by_technique: dict[str, list[str]] = {} +_name_by_id: dict[str, str] = {} +# matrix: list of tactic dicts (built once at load time) +_matrix: list[dict[str, Any]] = [] def _extract_tactics(obj: dict[str, Any]) -> list[str]: @@ -20,7 +39,7 @@ def _extract_tactics(obj: dict[str, Any]) -> list[str]: return [ p["phase_name"] for p in phases - if isinstance(p, dict) and "phase_name" in p + if isinstance(p, dict) and "phase_name" in p and p.get("kill_chain_name") == "mitre-attack" ] @@ -31,9 +50,65 @@ def _get_external_id(obj: dict[str, Any]) -> str | None: return None +def _is_subtechnique(tech_id: str) -> bool: + return "." in tech_id + + +def _parent_id(sub_id: str) -> str: + return sub_id.split(".")[0] + + +def _build_matrix(entries: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Build the tactic → techniques → subtechniques tree.""" + # Group top-level techniques by tactic. + tactic_techs: dict[str, list[dict[str, Any]]] = {t: [] for t in _TACTIC_ORDER} + + for entry in entries: + if _is_subtechnique(entry["id"]): + continue + for tactic in entry["tactics"]: + if tactic in tactic_techs: + tactic_techs[tactic].append(entry) + + # Attach sub-techniques to their parents. + parent_subs: dict[str, list[dict[str, Any]]] = {} + for entry in entries: + if not _is_subtechnique(entry["id"]): + continue + pid = _parent_id(entry["id"]) + parent_subs.setdefault(pid, []).append({"id": entry["id"], "name": entry["name"]}) + + # Sort subs alphabetically by name. + for subs in parent_subs.values(): + subs.sort(key=lambda x: x["name"]) + + matrix: list[dict[str, Any]] = [] + for tactic_id in _TACTIC_ORDER: + techs = tactic_techs.get(tactic_id, []) + # Sort techniques alphabetically. + techs_sorted = sorted(techs, key=lambda x: x["name"]) + tactic_name = tactic_id.replace("-", " ").title() + matrix.append( + { + "tactic_id": tactic_id, + "tactic_name": tactic_name, + "techniques": [ + { + "id": t["id"], + "name": t["name"], + "subtechniques": parent_subs.get(t["id"], []), + } + for t in techs_sorted + ], + } + ) + + return matrix + + def load_bundle(path: Path | None = None) -> None: """Load the MITRE bundle into memory. Called once at app boot.""" - global mitre_loaded, _index + global mitre_loaded, _index, _tactics_by_technique, _name_by_id, _matrix bundle_path = path or _BUNDLE_PATH try: @@ -49,6 +124,9 @@ def load_bundle(path: Path | None = None) -> None: return entries: list[dict[str, Any]] = [] + tactics_map: dict[str, list[str]] = {} + name_map: dict[str, str] = {} + for obj in data.get("objects") or []: if not isinstance(obj, dict): continue @@ -59,19 +137,35 @@ def load_bundle(path: Path | None = None) -> None: ext_id = _get_external_id(obj) if not ext_id: continue - entries.append( - { - "id": ext_id, - "name": obj.get("name", ""), - "tactics": _extract_tactics(obj), - } - ) + tactics = _extract_tactics(obj) + name = obj.get("name", "") + entries.append({"id": ext_id, "name": name, "tactics": tactics}) + tactics_map[ext_id] = tactics + name_map[ext_id] = name _index = entries + _tactics_by_technique = tactics_map + _name_by_id = name_map + _matrix = _build_matrix(entries) mitre_loaded = True logger.info("MITRE bundle loaded: %d techniques", len(_index)) +def get_tactics(technique_id: str) -> list[str]: + """Return tactic list for a technique id; empty list if unknown.""" + return _tactics_by_technique.get(technique_id, []) + + +def lookup_name(technique_id: str) -> str | None: + """Return the name for a technique id, or None if not in the bundle.""" + return _name_by_id.get(technique_id) + + +def get_matrix() -> list[dict[str, Any]]: + """Return the full tactic → techniques → subtechniques tree.""" + return _matrix + + def search(query: str, limit: int = 20) -> list[dict[str, Any]]: """Return up to `limit` techniques matching `query`. diff --git a/backend/app/services/simulation_workflow.py b/backend/app/services/simulation_workflow.py index 535ccfd..2c8ddd7 100644 --- a/backend/app/services/simulation_workflow.py +++ b/backend/app/services/simulation_workflow.py @@ -10,11 +10,10 @@ from backend.app.extensions import db from backend.app.models import User from backend.app.models.simulation import Simulation, SimulationStatus +# Fields only admin/redteam may write (excluding technique_ids which is handled separately). REDTEAM_FIELDS = frozenset( { "name", - "mitre_technique_id", - "mitre_technique_name", "description", "commands", "prerequisites", @@ -25,8 +24,6 @@ REDTEAM_FIELDS = frozenset( SOC_FIELDS = frozenset({"log_source", "logs", "soc_comment", "incident_number"}) -# Transitions allowed via POST /transition endpoint (manual only). -# auto pending→in_progress is handled in apply_patch, not here. _ALLOWED_TRANSITIONS: dict[str, dict[str, set[str]]] = { "review_required": { "from": {"pending", "in_progress"}, @@ -48,6 +45,27 @@ def _is_non_empty(value: Any) -> bool: return not (isinstance(value, list) and len(value) == 0) +def _resolve_technique_ids( + technique_ids: list[str], +) -> tuple[list[dict[str, str]] | None, tuple[Any, int] | None]: + """Validate and resolve technique IDs to [{id, name}] snapshots. + + Returns (resolved_list, None) on success or (None, error_tuple) on failure. + Deduplicates while preserving order. + """ + from backend.app.services import mitre as mitre_svc + + # Dedup, preserve order. + seen: dict[str, None] = dict.fromkeys(technique_ids) + resolved: list[dict[str, str]] = [] + for tid in seen: + name = mitre_svc.lookup_name(tid) + if name is None: + return None, (jsonify({"error": f"unknown technique id: {tid}"}), 400) + resolved.append({"id": tid, "name": name}) + return resolved, None + + def apply_patch( simulation: Simulation, payload: dict[str, Any], user: User ) -> tuple[Any, int] | None: @@ -59,15 +77,14 @@ def apply_patch( role = user.role.value if role == "soc": - # SOC can only patch when status allows it. if simulation.status not in ( SimulationStatus.REVIEW_REQUIRED, SimulationStatus.DONE, ): return jsonify({"error": "simulation not ready for SOC review"}), 403 - # SOC must not send redteam fields. - redteam_keys_in_payload = REDTEAM_FIELDS & payload.keys() + # SOC must not send redteam fields or technique_ids. + redteam_keys_in_payload = (REDTEAM_FIELDS | {"technique_ids"}) & payload.keys() if redteam_keys_in_payload: return jsonify({"error": "soc cannot edit redteam fields"}), 403 @@ -76,10 +93,10 @@ def apply_patch( setattr(simulation, field, payload[field]) else: - # admin / redteam: apply all fields present. + # admin / redteam path. redteam_keys_present = REDTEAM_FIELDS & payload.keys() - # Validate executed_at before any writes so a bad value causes no partial mutation. + # Validate executed_at upfront before any writes. executed_at_value: datetime | None = None if "executed_at" in redteam_keys_present: val = payload["executed_at"] @@ -91,21 +108,39 @@ def apply_patch( except ValueError: return jsonify({"error": "invalid executed_at"}), 400 + # Validate and resolve technique_ids upfront. + resolved_techniques: list[dict[str, str]] | None = None + if "technique_ids" in payload: + raw_ids = payload["technique_ids"] + if not isinstance(raw_ids, list): + return jsonify({"error": "technique_ids must be a list"}), 400 + resolved_techniques, err = _resolve_technique_ids(raw_ids) + if err is not None: + return err + + # Apply scalar redteam fields. for field in redteam_keys_present: if field == "executed_at": simulation.executed_at = executed_at_value else: setattr(simulation, field, payload[field]) + # Apply resolved techniques. + if resolved_techniques is not None: + simulation.techniques = resolved_techniques + + # Apply SOC fields (admin/redteam may also write them). for field in SOC_FIELDS: if field in payload: setattr(simulation, field, payload[field]) - # Auto-transition pending → in_progress: at least one redteam field with - # a non-empty value in the *incoming payload*. - if simulation.status == SimulationStatus.PENDING and any( - _is_non_empty(payload[k]) for k in redteam_keys_present - ): + # Auto-transition pending → in_progress. + # Triggers when any redteam scalar has a non-empty value, OR technique_ids is non-empty. + auto_trigger = any(_is_non_empty(payload[k]) for k in redteam_keys_present) + if not auto_trigger and "technique_ids" in payload: + auto_trigger = len(payload["technique_ids"]) > 0 + + if simulation.status == SimulationStatus.PENDING and auto_trigger: simulation.status = SimulationStatus.IN_PROGRESS simulation.updated_at = datetime.now(UTC) diff --git a/backend/migrations/versions/0003_simulation_techniques_array.py b/backend/migrations/versions/0003_simulation_techniques_array.py new file mode 100644 index 0000000..84fa165 --- /dev/null +++ b/backend/migrations/versions/0003_simulation_techniques_array.py @@ -0,0 +1,82 @@ +"""replace scalar MITRE columns with techniques JSON array + +Revision ID: 0003 +Revises: 0002 +Create Date: 2026-05-27 00:00:00.000000 +""" +import json + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.sql import column, table, text + + +revision = "0003" +down_revision = "0002" +branch_labels = None +depends_on = None + +# Lightweight table proxies for data migration (no ORM import). +_sims = table( + "simulations", + column("id", sa.Integer), + column("mitre_technique_id", sa.String), + column("mitre_technique_name", sa.String), + column("techniques", sa.Text), +) + + +def upgrade(): + bind = op.get_bind() + + # 1. Add techniques column (nullable while we backfill). + op.add_column("simulations", sa.Column("techniques", sa.Text(), nullable=True)) + + # 2. Backfill: scalar → JSON array. + rows = bind.execute( + text("SELECT id, mitre_technique_id, mitre_technique_name FROM simulations") + ).fetchall() + for row in rows: + if row[1]: # mitre_technique_id is not null + val = json.dumps([{"id": row[1], "name": row[2] or ""}]) + else: + val = "[]" + bind.execute( + text("UPDATE simulations SET techniques = :v WHERE id = :id"), + {"v": val, "id": row[0]}, + ) + + # 3. Make NOT NULL now that every row has a value. + # SQLite doesn't support ALTER COLUMN, so we skip the nullable constraint + # change at DDL level — the application model enforces it. + + # 4. Drop old scalar columns. + with op.batch_alter_table("simulations") as batch_op: + batch_op.drop_column("mitre_technique_id") + batch_op.drop_column("mitre_technique_name") + + +def downgrade(): + bind = op.get_bind() + + # 1. Re-add scalar columns. + with op.batch_alter_table("simulations") as batch_op: + batch_op.add_column(sa.Column("mitre_technique_id", sa.String(length=32), nullable=True)) + batch_op.add_column(sa.Column("mitre_technique_name", sa.String(length=255), nullable=True)) + + # 2. Back-fill: take first element of techniques array. + rows = bind.execute(text("SELECT id, techniques FROM simulations")).fetchall() + for row in rows: + techniques = json.loads(row[1] or "[]") + if techniques: + first = techniques[0] + bind.execute( + text( + "UPDATE simulations SET mitre_technique_id = :tid, mitre_technique_name = :tname WHERE id = :id" + ), + {"tid": first.get("id"), "tname": first.get("name"), "id": row[0]}, + ) + + # 3. Drop techniques column. + with op.batch_alter_table("simulations") as batch_op: + batch_op.drop_column("techniques") diff --git a/backend/tests/test_mitre.py b/backend/tests/test_mitre.py index a31c3cd..2057571 100644 --- a/backend/tests/test_mitre.py +++ b/backend/tests/test_mitre.py @@ -33,6 +33,14 @@ _FIXTURE_BUNDLE = { ], "kill_chain_phases": [{"phase_name": "execution", "kill_chain_name": "mitre-attack"}], }, + { + "type": "attack-pattern", + "name": "Python", + "external_references": [ + {"source_name": "mitre-attack", "external_id": "T1059.006"} + ], + "kill_chain_phases": [{"phase_name": "execution", "kill_chain_name": "mitre-attack"}], + }, { "type": "attack-pattern", "name": "Phishing", @@ -76,9 +84,15 @@ def _reset_mitre(): """Reset the MITRE service state between tests.""" original_loaded = mitre_svc.mitre_loaded original_index = list(mitre_svc._index) + original_tactics = dict(mitre_svc._tactics_by_technique) + original_names = dict(mitre_svc._name_by_id) + original_matrix = list(mitre_svc._matrix) yield mitre_svc.mitre_loaded = original_loaded mitre_svc._index = original_index + mitre_svc._tactics_by_technique = original_tactics + mitre_svc._name_by_id = original_names + mitre_svc._matrix = original_matrix @pytest.fixture() @@ -96,7 +110,7 @@ def bundle_file(tmp_path: pathlib.Path) -> pathlib.Path: def test_load_bundle_success(bundle_file: pathlib.Path) -> None: mitre_svc.load_bundle(bundle_file) assert mitre_svc.mitre_loaded is True - assert len(mitre_svc._index) == 4 # 5 objects minus 1 revoked = 4 + assert len(mitre_svc._index) == 5 # 6 attack-patterns minus 1 revoked = 5 def test_load_bundle_missing_file() -> None: @@ -245,3 +259,119 @@ def test_mitre_endpoint_includes_tactics( phishing = next((r for r in data if r["id"] == "T1566"), None) assert phishing is not None assert "initial-access" in phishing["tactics"] + + +# --------------------------------------------------------------------------- +# Sprint 3: get_tactics, lookup_name, get_matrix +# --------------------------------------------------------------------------- + + +def test_get_tactics_known(bundle_file: pathlib.Path) -> None: + mitre_svc.load_bundle(bundle_file) + tactics = mitre_svc.get_tactics("T1078") + assert "initial-access" in tactics + assert "persistence" in tactics + + +def test_get_tactics_unknown_returns_empty(bundle_file: pathlib.Path) -> None: + mitre_svc.load_bundle(bundle_file) + assert mitre_svc.get_tactics("T0000") == [] + + +def test_lookup_name_known(bundle_file: pathlib.Path) -> None: + mitre_svc.load_bundle(bundle_file) + assert mitre_svc.lookup_name("T1059") == "Command and Scripting Interpreter" + + +def test_lookup_name_subtechnique(bundle_file: pathlib.Path) -> None: + mitre_svc.load_bundle(bundle_file) + assert mitre_svc.lookup_name("T1059.001") == "PowerShell" + + +def test_lookup_name_unknown_returns_none(bundle_file: pathlib.Path) -> None: + mitre_svc.load_bundle(bundle_file) + assert mitre_svc.lookup_name("T0000") is None + + +def test_get_matrix_returns_ordered_tactics(bundle_file: pathlib.Path) -> None: + mitre_svc.load_bundle(bundle_file) + matrix = mitre_svc.get_matrix() + tactic_ids = [t["tactic_id"] for t in matrix] + # initial-access must come before execution in canonical order. + assert tactic_ids.index("initial-access") < tactic_ids.index("execution") + + +def test_get_matrix_subtechniques_nested(bundle_file: pathlib.Path) -> None: + mitre_svc.load_bundle(bundle_file) + matrix = mitre_svc.get_matrix() + exec_tactic = next(t for t in matrix if t["tactic_id"] == "execution") + t1059 = next((t for t in exec_tactic["techniques"] if t["id"] == "T1059"), None) + assert t1059 is not None + sub_ids = [s["id"] for s in t1059["subtechniques"]] + assert "T1059.001" in sub_ids + assert "T1059.006" in sub_ids + + +def test_get_matrix_subtechniques_sorted_by_name(bundle_file: pathlib.Path) -> None: + mitre_svc.load_bundle(bundle_file) + matrix = mitre_svc.get_matrix() + exec_tactic = next(t for t in matrix if t["tactic_id"] == "execution") + t1059 = next(t for t in exec_tactic["techniques"] if t["id"] == "T1059") + names = [s["name"] for s in t1059["subtechniques"]] + assert names == sorted(names) + + +def test_get_matrix_techniques_sorted_by_name(bundle_file: pathlib.Path) -> None: + mitre_svc.load_bundle(bundle_file) + matrix = mitre_svc.get_matrix() + ia_tactic = next(t for t in matrix if t["tactic_id"] == "initial-access") + names = [t["name"] for t in ia_tactic["techniques"]] + assert names == sorted(names) + + +def test_get_matrix_technique_no_subtechniques(bundle_file: pathlib.Path) -> None: + mitre_svc.load_bundle(bundle_file) + matrix = mitre_svc.get_matrix() + ia_tactic = next(t for t in matrix if t["tactic_id"] == "initial-access") + phishing = next((t for t in ia_tactic["techniques"] if t["id"] == "T1566"), None) + assert phishing is not None + assert phishing["subtechniques"] == [] + + +def test_matrix_endpoint_ok( + client: FlaskClient, redteam_token: str, bundle_file: pathlib.Path +) -> None: + mitre_svc.load_bundle(bundle_file) + resp = client.get("/api/mitre/matrix", headers=_h(redteam_token)) + assert resp.status_code == 200 + data = resp.get_json() + assert isinstance(data, list) + tactic_ids = [t["tactic_id"] for t in data] + assert "initial-access" in tactic_ids + assert "execution" in tactic_ids + + +def test_matrix_endpoint_503_when_not_loaded( + client: FlaskClient, redteam_token: str +) -> None: + mitre_svc.mitre_loaded = False + resp = client.get("/api/mitre/matrix", headers=_h(redteam_token)) + assert resp.status_code == 503 + + +def test_matrix_endpoint_requires_auth(client: FlaskClient) -> None: + resp = client.get("/api/mitre/matrix") + assert resp.status_code == 401 + + +def test_matrix_endpoint_all_roles( + client: FlaskClient, + redteam_token: str, + soc_token: str, + admin_token: str, + bundle_file: pathlib.Path, +) -> None: + mitre_svc.load_bundle(bundle_file) + for token in (redteam_token, soc_token, admin_token): + resp = client.get("/api/mitre/matrix", headers=_h(token)) + assert resp.status_code == 200 diff --git a/backend/tests/test_simulations_techniques.py b/backend/tests/test_simulations_techniques.py new file mode 100644 index 0000000..5212577 --- /dev/null +++ b/backend/tests/test_simulations_techniques.py @@ -0,0 +1,347 @@ +"""Sprint 3 — multi-technique simulation tests (AC-13).""" +from __future__ import annotations + +import json +import pathlib + +import pytest +from flask.testing import FlaskClient + +from backend.app.services import mitre as mitre_svc +from backend.tests.conftest import auth_headers as _h + +# --------------------------------------------------------------------------- +# Minimal STIX fixture (reused from test_mitre.py pattern) +# --------------------------------------------------------------------------- + +_FIXTURE_BUNDLE = { + "type": "bundle", + "objects": [ + { + "type": "attack-pattern", + "name": "Command and Scripting Interpreter", + "external_references": [{"source_name": "mitre-attack", "external_id": "T1059"}], + "kill_chain_phases": [{"phase_name": "execution", "kill_chain_name": "mitre-attack"}], + }, + { + "type": "attack-pattern", + "name": "PowerShell", + "external_references": [{"source_name": "mitre-attack", "external_id": "T1059.001"}], + "kill_chain_phases": [{"phase_name": "execution", "kill_chain_name": "mitre-attack"}], + }, + { + "type": "attack-pattern", + "name": "Valid Accounts", + "external_references": [{"source_name": "mitre-attack", "external_id": "T1078"}], + "kill_chain_phases": [ + {"phase_name": "initial-access", "kill_chain_name": "mitre-attack"}, + {"phase_name": "persistence", "kill_chain_name": "mitre-attack"}, + ], + }, + ], +} + + +@pytest.fixture(autouse=True) +def _reset_mitre(): + original_loaded = mitre_svc.mitre_loaded + original_index = list(mitre_svc._index) + original_tactics = dict(mitre_svc._tactics_by_technique) + original_names = dict(mitre_svc._name_by_id) + original_matrix = list(mitre_svc._matrix) + yield + mitre_svc.mitre_loaded = original_loaded + mitre_svc._index = original_index + mitre_svc._tactics_by_technique = original_tactics + mitre_svc._name_by_id = original_names + mitre_svc._matrix = original_matrix + + +@pytest.fixture() +def bundle_file(tmp_path: pathlib.Path) -> pathlib.Path: + p = tmp_path / "enterprise-attack.json" + p.write_text(json.dumps(_FIXTURE_BUNDLE), encoding="utf-8") + return p + + +@pytest.fixture() +def loaded_bundle(bundle_file: pathlib.Path) -> pathlib.Path: + mitre_svc.load_bundle(bundle_file) + return bundle_file + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_engagement(client: FlaskClient, token: str) -> dict: + resp = client.post( + "/api/engagements", + headers=_h(token), + json={"name": "Op Sprint3", "start_date": "2026-06-01"}, + ) + assert resp.status_code == 201 + return resp.get_json() + + +def _make_sim(client: FlaskClient, token: str, eid: int) -> dict: + resp = client.post( + f"/api/engagements/{eid}/simulations", + headers=_h(token), + json={"name": "Technique Test"}, + ) + assert resp.status_code == 201 + return resp.get_json() + + +def _patch(client: FlaskClient, token: str, sid: int, payload: dict): + return client.patch(f"/api/simulations/{sid}", headers=_h(token), json=payload) + + +# --------------------------------------------------------------------------- +# AC-13.1 — new simulation has techniques = [] +# --------------------------------------------------------------------------- + + +def test_new_simulation_has_empty_techniques( + client: FlaskClient, redteam_token: str, loaded_bundle +) -> None: + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + assert sim["techniques"] == [] + + +# --------------------------------------------------------------------------- +# AC-13.3 — serializer enriches techniques with tactics +# --------------------------------------------------------------------------- + + +def test_techniques_enriched_with_tactics( + client: FlaskClient, redteam_token: str, loaded_bundle +) -> None: + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1078"]}) + + resp = client.get(f"/api/simulations/{sim['id']}", headers=_h(redteam_token)) + assert resp.status_code == 200 + techs = resp.get_json()["techniques"] + assert len(techs) == 1 + assert techs[0]["id"] == "T1078" + assert "initial-access" in techs[0]["tactics"] + assert "persistence" in techs[0]["tactics"] + + +def test_techniques_with_unknown_id_returns_empty_tactics( + client: FlaskClient, redteam_token: str, loaded_bundle +) -> None: + """If a technique was removed from the bundle after save, tactics gracefully = [].""" + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + # Bypass service, write directly an id not in the bundle. + from backend.app.extensions import db + from backend.app.models.simulation import Simulation + + with client.application.app_context(): + s = db.session.get(Simulation, sim["id"]) + s.techniques = [{"id": "T0000", "name": "Removed Technique"}] + db.session.commit() + + resp = client.get(f"/api/simulations/{sim['id']}", headers=_h(redteam_token)) + techs = resp.get_json()["techniques"] + assert techs[0]["tactics"] == [] + + +# --------------------------------------------------------------------------- +# AC-13.4 — PATCH technique_ids +# --------------------------------------------------------------------------- + + +def test_patch_technique_ids_sets_techniques( + client: FlaskClient, redteam_token: str, loaded_bundle +) -> None: + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + + resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059", "T1078"]}) + assert resp.status_code == 200 + techs = resp.get_json()["techniques"] + assert len(techs) == 2 + ids = [t["id"] for t in techs] + assert "T1059" in ids + assert "T1078" in ids + + +def test_patch_technique_ids_resolves_name( + client: FlaskClient, redteam_token: str, loaded_bundle +) -> None: + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + + resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059"]}) + assert resp.status_code == 200 + tech = resp.get_json()["techniques"][0] + assert tech["name"] == "Command and Scripting Interpreter" + + +def test_patch_technique_ids_unknown_returns_400( + client: FlaskClient, redteam_token: str, loaded_bundle +) -> None: + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + + resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T9999"]}) + assert resp.status_code == 400 + assert "unknown technique id: T9999" in resp.get_json()["error"] + + +def test_patch_technique_ids_partial_unknown_rejected( + client: FlaskClient, redteam_token: str, loaded_bundle +) -> None: + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + + # One valid, one unknown — whole request rejected. + resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059", "T9999"]}) + assert resp.status_code == 400 + + +def test_patch_technique_ids_includes_subtechnique( + client: FlaskClient, redteam_token: str, loaded_bundle +) -> None: + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + + resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059.001"]}) + assert resp.status_code == 200 + techs = resp.get_json()["techniques"] + assert techs[0]["id"] == "T1059.001" + assert techs[0]["name"] == "PowerShell" + + +def test_patch_technique_ids_replaces_list( + client: FlaskClient, redteam_token: str, loaded_bundle +) -> None: + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059"]}) + + resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1078"]}) + assert resp.status_code == 200 + ids = [t["id"] for t in resp.get_json()["techniques"]] + assert ids == ["T1078"] + + +def test_patch_technique_ids_empty_clears_list( + client: FlaskClient, redteam_token: str, loaded_bundle +) -> None: + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059"]}) + + resp = _patch(client, redteam_token, sim["id"], {"technique_ids": []}) + assert resp.status_code == 200 + assert resp.get_json()["techniques"] == [] + + +def test_patch_technique_ids_not_list_returns_400( + client: FlaskClient, redteam_token: str, loaded_bundle +) -> None: + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + + resp = _patch(client, redteam_token, sim["id"], {"technique_ids": "T1059"}) + assert resp.status_code == 400 + + +# --------------------------------------------------------------------------- +# Dedup (spec-reviewer note: AC-13.4) +# --------------------------------------------------------------------------- + + +def test_patch_technique_ids_deduplicates( + client: FlaskClient, redteam_token: str, loaded_bundle +) -> None: + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + + resp = _patch( + client, redteam_token, sim["id"], {"technique_ids": ["T1059", "T1078", "T1059"]} + ) + assert resp.status_code == 200 + techs = resp.get_json()["techniques"] + assert len(techs) == 2 + # Order preserved: T1059 first. + assert techs[0]["id"] == "T1059" + assert techs[1]["id"] == "T1078" + + +# --------------------------------------------------------------------------- +# AC-13.5 — auto-transition on technique_ids +# --------------------------------------------------------------------------- + + +def test_technique_ids_non_empty_triggers_auto_transition( + client: FlaskClient, redteam_token: str, loaded_bundle +) -> None: + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + assert sim["status"] == "pending" + + resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059"]}) + assert resp.status_code == 200 + assert resp.get_json()["status"] == "in_progress" + + +def test_technique_ids_empty_does_not_trigger_auto_transition( + client: FlaskClient, redteam_token: str, loaded_bundle +) -> None: + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + + resp = _patch(client, redteam_token, sim["id"], {"technique_ids": []}) + assert resp.status_code == 200 + assert resp.get_json()["status"] == "pending" + + +# --------------------------------------------------------------------------- +# SOC cannot patch technique_ids (it's a redteam field) +# --------------------------------------------------------------------------- + + +def test_soc_cannot_patch_technique_ids( + client: FlaskClient, redteam_token: str, soc_token: str, loaded_bundle +) -> None: + eng = _make_engagement(client, redteam_token) + sim = _make_sim(client, redteam_token, eng["id"]) + # Advance to review_required so SOC can touch the simulation at all. + client.post( + f"/api/simulations/{sim['id']}/transition", + headers=_h(redteam_token), + json={"to": "review_required"}, + ) + + resp = _patch(client, soc_token, sim["id"], {"technique_ids": ["T1059"]}) + assert resp.status_code == 403 + + +# --------------------------------------------------------------------------- +# Migration backfill test (inline, no Alembic runner needed) +# --------------------------------------------------------------------------- + + +def test_migration_backfill_logic() -> None: + """Verify the backfill logic used in upgrade(): scalar → [{id, name}].""" + import json as _json + + def _backfill(tech_id, tech_name): + if tech_id: + return _json.loads(_json.dumps([{"id": tech_id, "name": tech_name or ""}])) + return [] + + assert _backfill("T1059", "Command and Scripting Interpreter") == [ + {"id": "T1059", "name": "Command and Scripting Interpreter"} + ] + assert _backfill(None, None) == [] + assert _backfill("T1059", None) == [{"id": "T1059", "name": ""}]