"""Sprint 3 — multi-technique simulation tests (AC-13).""" from __future__ import annotations import json import pathlib import pytest from flask.testing import FlaskClient from backend.app.services import mitre as mitre_svc from backend.tests.conftest import auth_headers as _h # --------------------------------------------------------------------------- # Minimal STIX fixture (reused from test_mitre.py pattern) # --------------------------------------------------------------------------- _FIXTURE_BUNDLE = { "type": "bundle", "objects": [ { "type": "attack-pattern", "name": "Command and Scripting Interpreter", "external_references": [{"source_name": "mitre-attack", "external_id": "T1059"}], "kill_chain_phases": [{"phase_name": "execution", "kill_chain_name": "mitre-attack"}], }, { "type": "attack-pattern", "name": "PowerShell", "external_references": [{"source_name": "mitre-attack", "external_id": "T1059.001"}], "kill_chain_phases": [{"phase_name": "execution", "kill_chain_name": "mitre-attack"}], }, { "type": "attack-pattern", "name": "Valid Accounts", "external_references": [{"source_name": "mitre-attack", "external_id": "T1078"}], "kill_chain_phases": [ {"phase_name": "initial-access", "kill_chain_name": "mitre-attack"}, {"phase_name": "persistence", "kill_chain_name": "mitre-attack"}, ], }, ], } @pytest.fixture(autouse=True) def _reset_mitre(): original_loaded = mitre_svc.mitre_loaded original_index = list(mitre_svc._index) original_tactics = dict(mitre_svc._tactics_by_technique) original_names = dict(mitre_svc._name_by_id) original_matrix = list(mitre_svc._matrix) yield mitre_svc.mitre_loaded = original_loaded mitre_svc._index = original_index mitre_svc._tactics_by_technique = original_tactics mitre_svc._name_by_id = original_names mitre_svc._matrix = original_matrix @pytest.fixture() def bundle_file(tmp_path: pathlib.Path) -> pathlib.Path: p = tmp_path / "enterprise-attack.json" p.write_text(json.dumps(_FIXTURE_BUNDLE), encoding="utf-8") return p 
@pytest.fixture()
def loaded_bundle(bundle_file: pathlib.Path) -> pathlib.Path:
    """Load the fixture bundle into the MITRE service and return its path."""
    mitre_svc.load_bundle(bundle_file)
    return bundle_file


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_engagement(client: FlaskClient, token: str) -> dict:
    """Create an engagement through the API, assert 201, and return its JSON."""
    resp = client.post(
        "/api/engagements",
        headers=_h(token),
        json={"name": "Op Sprint3", "start_date": "2026-06-01"},
    )
    assert resp.status_code == 201
    return resp.get_json()


def _make_sim(client: FlaskClient, token: str, eid: int) -> dict:
    """Create a simulation under engagement *eid*, assert 201, and return its JSON."""
    resp = client.post(
        f"/api/engagements/{eid}/simulations",
        headers=_h(token),
        json={"name": "Technique Test"},
    )
    assert resp.status_code == 201
    return resp.get_json()


def _patch(client: FlaskClient, token: str, sid: int, payload: dict):
    """PATCH simulation *sid* with *payload* and return the raw response."""
    return client.patch(f"/api/simulations/{sid}", headers=_h(token), json=payload)


def _patch_ok(client: FlaskClient, token: str, sid: int, payload: dict) -> dict:
    """PATCH simulation *sid*, assert it returned 200, and return the JSON body.

    Use this for setup steps: if the precondition PATCH fails, the test fails
    here instead of letting its real assertion pass vacuously.
    """
    resp = _patch(client, token, sid, payload)
    assert resp.status_code == 200, resp.get_json()
    return resp.get_json()


# ---------------------------------------------------------------------------
# AC-13.1 — new simulation has techniques = []
# ---------------------------------------------------------------------------


def test_new_simulation_has_empty_techniques(
    client: FlaskClient, redteam_token: str, loaded_bundle
) -> None:
    eng = _make_engagement(client, redteam_token)
    sim = _make_sim(client, redteam_token, eng["id"])
    assert sim["techniques"] == []


# ---------------------------------------------------------------------------
# AC-13.3 — serializer enriches techniques with tactics
# ---------------------------------------------------------------------------


def test_techniques_enriched_with_tactics(
    client: FlaskClient, redteam_token: str, loaded_bundle
) -> None:
    eng = _make_engagement(client, redteam_token)
    sim = _make_sim(client, redteam_token, eng["id"])
    _patch_ok(client, redteam_token, sim["id"], {"technique_ids": ["T1078"]})
    resp = client.get(f"/api/simulations/{sim['id']}", headers=_h(redteam_token))
    assert resp.status_code == 200
    techs = resp.get_json()["techniques"]
    assert len(techs) == 1
    assert techs[0]["id"] == "T1078"
    # T1078 carries two kill-chain phases in the fixture bundle.
    assert "initial-access" in techs[0]["tactics"]
    assert "persistence" in techs[0]["tactics"]


def test_techniques_with_unknown_id_returns_empty_tactics(
    client: FlaskClient, redteam_token: str, loaded_bundle
) -> None:
    """If a technique was removed from the bundle after save, tactics gracefully = []."""
    eng = _make_engagement(client, redteam_token)
    sim = _make_sim(client, redteam_token, eng["id"])
    # Bypass service, write directly an id not in the bundle.
    from backend.app.extensions import db
    from backend.app.models.simulation import Simulation

    with client.application.app_context():
        s = db.session.get(Simulation, sim["id"])
        s.techniques = [{"id": "T0000", "name": "Removed Technique"}]
        db.session.commit()

    resp = client.get(f"/api/simulations/{sim['id']}", headers=_h(redteam_token))
    assert resp.status_code == 200
    techs = resp.get_json()["techniques"]
    # Guard the index below so a missing entry fails with a clear message.
    assert len(techs) == 1
    assert techs[0]["id"] == "T0000"
    assert techs[0]["tactics"] == []


# ---------------------------------------------------------------------------
# AC-13.4 — PATCH technique_ids
# ---------------------------------------------------------------------------


def test_patch_technique_ids_sets_techniques(
    client: FlaskClient, redteam_token: str, loaded_bundle
) -> None:
    eng = _make_engagement(client, redteam_token)
    sim = _make_sim(client, redteam_token, eng["id"])
    resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059", "T1078"]})
    assert resp.status_code == 200
    techs = resp.get_json()["techniques"]
    assert len(techs) == 2
    ids = [t["id"] for t in techs]
    assert "T1059" in ids
    assert "T1078" in ids


def test_patch_technique_ids_resolves_name(
    client: FlaskClient, redteam_token: str, loaded_bundle
) -> None:
    eng = _make_engagement(client, redteam_token)
    sim = _make_sim(client, redteam_token, eng["id"])
    resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059"]})
    assert resp.status_code == 200
    tech = resp.get_json()["techniques"][0]
    assert tech["name"] == "Command and Scripting Interpreter"


def test_patch_technique_ids_unknown_returns_400(
    client: FlaskClient, redteam_token: str, loaded_bundle
) -> None:
    eng = _make_engagement(client, redteam_token)
    sim = _make_sim(client, redteam_token, eng["id"])
    resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T9999"]})
    assert resp.status_code == 400
    assert "unknown technique id: T9999" in resp.get_json()["error"]


def test_patch_technique_ids_partial_unknown_rejected(
    client: FlaskClient, redteam_token: str, loaded_bundle
) -> None:
    eng = _make_engagement(client, redteam_token)
    sim = _make_sim(client, redteam_token, eng["id"])
    # One valid, one unknown — whole request rejected.
    resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059", "T9999"]})
    assert resp.status_code == 400
    # "Whole request rejected" also means the valid id must not have been stored.
    after = client.get(f"/api/simulations/{sim['id']}", headers=_h(redteam_token))
    assert after.status_code == 200
    assert after.get_json()["techniques"] == []


def test_patch_technique_ids_includes_subtechnique(
    client: FlaskClient, redteam_token: str, loaded_bundle
) -> None:
    eng = _make_engagement(client, redteam_token)
    sim = _make_sim(client, redteam_token, eng["id"])
    resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059.001"]})
    assert resp.status_code == 200
    techs = resp.get_json()["techniques"]
    assert techs[0]["id"] == "T1059.001"
    assert techs[0]["name"] == "PowerShell"


def test_patch_technique_ids_replaces_list(
    client: FlaskClient, redteam_token: str, loaded_bundle
) -> None:
    eng = _make_engagement(client, redteam_token)
    sim = _make_sim(client, redteam_token, eng["id"])
    # Seed must succeed, otherwise "replace" is indistinguishable from "set".
    _patch_ok(client, redteam_token, sim["id"], {"technique_ids": ["T1059"]})
    resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1078"]})
    assert resp.status_code == 200
    ids = [t["id"] for t in resp.get_json()["techniques"]]
    assert ids == ["T1078"]


def test_patch_technique_ids_empty_clears_list(
    client: FlaskClient, redteam_token: str, loaded_bundle
) -> None:
    eng = _make_engagement(client, redteam_token)
    sim = _make_sim(client, redteam_token, eng["id"])
    # Seed a non-empty list first, otherwise clearing would be a no-op.
    seeded = _patch_ok(client, redteam_token, sim["id"], {"technique_ids": ["T1059"]})
    assert seeded["techniques"], "seed PATCH did not store any technique"
    resp = _patch(client, redteam_token, sim["id"], {"technique_ids": []})
    assert resp.status_code == 200
    assert resp.get_json()["techniques"] == []


def test_patch_technique_ids_not_list_returns_400(
    client: FlaskClient, redteam_token: str, loaded_bundle
) -> None:
    eng = _make_engagement(client, redteam_token)
    sim = _make_sim(client, redteam_token, eng["id"])
    resp = _patch(client, redteam_token, sim["id"], {"technique_ids": "T1059"})
    assert resp.status_code == 400


# ---------------------------------------------------------------------------
# Dedup (spec-reviewer note: AC-13.4)
# ---------------------------------------------------------------------------


def test_patch_technique_ids_deduplicates(
    client: FlaskClient, redteam_token: str, loaded_bundle
) -> None:
    eng = _make_engagement(client, redteam_token)
    sim = _make_sim(client, redteam_token, eng["id"])
    resp = _patch(
        client, redteam_token, sim["id"], {"technique_ids": ["T1059", "T1078", "T1059"]}
    )
    assert resp.status_code == 200
    techs = resp.get_json()["techniques"]
    assert len(techs) == 2
    # Order preserved: T1059 first.
    assert techs[0]["id"] == "T1059"
    assert techs[1]["id"] == "T1078"


# ---------------------------------------------------------------------------
# AC-13.5 — auto-transition on technique_ids
# ---------------------------------------------------------------------------


def test_technique_ids_non_empty_triggers_auto_transition(
    client: FlaskClient, redteam_token: str, loaded_bundle
) -> None:
    eng = _make_engagement(client, redteam_token)
    sim = _make_sim(client, redteam_token, eng["id"])
    assert sim["status"] == "pending"
    resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059"]})
    assert resp.status_code == 200
    assert resp.get_json()["status"] == "in_progress"


def test_technique_ids_empty_does_not_trigger_auto_transition(
    client: FlaskClient, redteam_token: str, loaded_bundle
) -> None:
    eng = _make_engagement(client, redteam_token)
    sim = _make_sim(client, redteam_token, eng["id"])
    resp = _patch(client, redteam_token, sim["id"], {"technique_ids": []})
    assert resp.status_code == 200
    assert resp.get_json()["status"] == "pending"


# ---------------------------------------------------------------------------
# Bundle not loaded — 503 on technique_ids PATCH
# ---------------------------------------------------------------------------


def test_patch_technique_ids_bundle_not_loaded_returns_503(
    client: FlaskClient, redteam_token: str
) -> None:
    """When MITRE bundle is absent, PATCH with technique_ids must return 503."""
    # Blank every piece of service state the autouse _reset_mitre fixture
    # snapshots, so the "not loaded" state is coherent; the fixture restores it.
    mitre_svc.mitre_loaded = False
    mitre_svc._index = []
    mitre_svc._tactics_by_technique = {}
    mitre_svc._name_by_id = {}
    mitre_svc._matrix = []
    eng = _make_engagement(client, redteam_token)
    sim = _make_sim(client, redteam_token, eng["id"])
    resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059"]})
    assert resp.status_code == 503
    assert resp.get_json()["error"] == "mitre bundle not loaded"
# ---------------------------------------------------------------------------
# SOC cannot patch technique_ids (it's a redteam field)
# ---------------------------------------------------------------------------


def test_soc_cannot_patch_technique_ids(
    client: FlaskClient, redteam_token: str, soc_token: str, loaded_bundle
) -> None:
    eng = _make_engagement(client, redteam_token)
    sim = _make_sim(client, redteam_token, eng["id"])
    # Advance to review_required so SOC can touch the simulation at all.
    transition = client.post(
        f"/api/simulations/{sim['id']}/transition",
        headers=_h(redteam_token),
        json={"to": "review_required"},
    )
    # Without this precondition the 403 below could simply mean "SOC may not
    # see this simulation yet" rather than "technique_ids is a redteam field".
    assert 200 <= transition.status_code < 300, transition.get_json()
    resp = _patch(client, soc_token, sim["id"], {"technique_ids": ["T1059"]})
    assert resp.status_code == 403


# ---------------------------------------------------------------------------
# Migration backfill test (inline, no Alembic runner needed)
# ---------------------------------------------------------------------------


def test_migration_backfill_logic() -> None:
    """Verify the backfill logic used in upgrade(): scalar → [{id, name}].

    This checks a local copy of the backfill expression, not the migration
    module itself; keep ``_backfill`` in sync with migration 0003.
    """
    import json as _json

    def _backfill(tech_id, tech_name):
        # Round-trip through JSON to mirror how the value is stored as text.
        if tech_id:
            return _json.loads(_json.dumps([{"id": tech_id, "name": tech_name or ""}]))
        return []

    assert _backfill("T1059", "Command and Scripting Interpreter") == [
        {"id": "T1059", "name": "Command and Scripting Interpreter"}
    ]
    assert _backfill(None, None) == []
    assert _backfill("T1059", None) == [{"id": "T1059", "name": ""}]


def test_migration_0003_techniques_not_null_after_upgrade() -> None:
    """Run migration 0003 upgrade() against a real SQLite DB and assert techniques is NOT NULL."""
    import importlib.util  # `import importlib` alone does not guarantee `.util` is loaded
    import json as _json

    import sqlalchemy as _sa
    from alembic.operations import Operations
    from alembic.runtime.migration import MigrationContext

    from backend.tests import conftest as _conftest

    # Locate the migration relative to the repository (backend/tests/conftest.py
    # -> backend/) instead of a developer-specific absolute path.
    backend_dir = pathlib.Path(_conftest.__file__).resolve().parents[1]
    migration_path = (
        backend_dir / "migrations" / "versions" / "0003_simulation_techniques_array.py"
    )
    assert migration_path.is_file(), f"migration file not found: {migration_path}"

    engine = _sa.create_engine("sqlite:///:memory:")
    with engine.begin() as conn:
        # Create the pre-migration schema (0002 state).
        conn.execute(_sa.text(
            "CREATE TABLE simulations ("
            "  id INTEGER PRIMARY KEY,"
            "  mitre_technique_id VARCHAR(32),"
            "  mitre_technique_name VARCHAR(255)"
            ")"
        ))
        conn.execute(_sa.text(
            "INSERT INTO simulations (id, mitre_technique_id, mitre_technique_name)"
            " VALUES (1, 'T1059', 'Command and Scripting Interpreter')"
        ))
        conn.execute(_sa.text(
            "INSERT INTO simulations (id, mitre_technique_id, mitre_technique_name)"
            " VALUES (2, NULL, NULL)"
        ))

    # Load the migration module directly from its file.
    spec = importlib.util.spec_from_file_location("mig_0003", str(migration_path))
    assert spec is not None and spec.loader is not None
    mig = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mig)  # type: ignore[union-attr]

    # Run upgrade() via Alembic Operations context.
    with engine.begin() as conn:
        ctx = MigrationContext.configure(conn, opts={"as_sql": False})
        # Operations.context installs the alembic.op proxy for the migration's
        # op.* calls and removes it on exit, so no global state leaks.
        with Operations.context(ctx):
            mig.upgrade()

    # Verify schema: techniques column exists and is NOT NULL.
    insp = _sa.inspect(engine)
    cols = {c["name"]: c for c in insp.get_columns("simulations")}
    assert "techniques" in cols, "techniques column must exist after upgrade"
    assert cols["techniques"]["nullable"] is False, "techniques must be NOT NULL after upgrade"
    assert "mitre_technique_id" not in cols
    assert "mitre_technique_name" not in cols

    # Verify data was backfilled correctly.
    with engine.connect() as conn:
        rows = conn.execute(
            _sa.text("SELECT id, techniques FROM simulations ORDER BY id")
        ).fetchall()
    assert len(rows) == 2
    assert _json.loads(rows[0][1]) == [{"id": "T1059", "name": "Command and Scripting Interpreter"}]
    assert _json.loads(rows[1][1]) == []