"""MITRE service and endpoint tests. Uses a tiny fixture bundle, not the 40 MB file.""" from __future__ import annotations import json import pathlib import pytest from flask.testing import FlaskClient from backend.app.services import mitre as mitre_svc from backend.tests.conftest import auth_headers as _h # --------------------------------------------------------------------------- # Fixture STIX bundle (minimal, 4 techniques including one sub-technique) # --------------------------------------------------------------------------- _FIXTURE_BUNDLE = { "type": "bundle", "objects": [ { "type": "attack-pattern", "name": "Command and Scripting Interpreter", "external_references": [ {"source_name": "mitre-attack", "external_id": "T1059"} ], "kill_chain_phases": [{"phase_name": "execution", "kill_chain_name": "mitre-attack"}], }, { "type": "attack-pattern", "name": "PowerShell", "external_references": [ {"source_name": "mitre-attack", "external_id": "T1059.001"} ], "kill_chain_phases": [{"phase_name": "execution", "kill_chain_name": "mitre-attack"}], }, { "type": "attack-pattern", "name": "Python", "external_references": [ {"source_name": "mitre-attack", "external_id": "T1059.006"} ], "kill_chain_phases": [{"phase_name": "execution", "kill_chain_name": "mitre-attack"}], }, { "type": "attack-pattern", "name": "Phishing", "external_references": [ {"source_name": "mitre-attack", "external_id": "T1566"} ], "kill_chain_phases": [{"phase_name": "initial-access", "kill_chain_name": "mitre-attack"}], }, { "type": "attack-pattern", "name": "Valid Accounts", "external_references": [ {"source_name": "mitre-attack", "external_id": "T1078"} ], "kill_chain_phases": [ {"phase_name": "initial-access", "kill_chain_name": "mitre-attack"}, {"phase_name": "persistence", "kill_chain_name": "mitre-attack"}, ], }, { # Revoked — must be excluded from index. "type": "attack-pattern", "name": "Old Technique", "revoked": True, "external_references": [ {"source_name": "mitre-attack", "external_id": "T9999"} ], "kill_chain_phases": [], }, { "type": "attack-pattern", "name": "Application Layer Protocol", "external_references": [ {"source_name": "mitre-attack", "external_id": "T1071"} ], "kill_chain_phases": [{"phase_name": "command-and-control", "kill_chain_name": "mitre-attack"}], }, { # Not an attack-pattern — must be ignored. "type": "relationship", "name": "Ignored", }, ], } @pytest.fixture(autouse=True) def _reset_mitre(): """Reset the MITRE service state between tests.""" original_loaded = mitre_svc.mitre_loaded original_index = list(mitre_svc._index) original_tactics = dict(mitre_svc._tactics_by_technique) original_names = dict(mitre_svc._name_by_id) original_matrix = list(mitre_svc._matrix) yield mitre_svc.mitre_loaded = original_loaded mitre_svc._index = original_index mitre_svc._tactics_by_technique = original_tactics mitre_svc._name_by_id = original_names mitre_svc._matrix = original_matrix @pytest.fixture() def bundle_file(tmp_path: pathlib.Path) -> pathlib.Path: p = tmp_path / "enterprise-attack.json" p.write_text(json.dumps(_FIXTURE_BUNDLE), encoding="utf-8") return p # --------------------------------------------------------------------------- # Unit tests for load_bundle # --------------------------------------------------------------------------- def test_load_bundle_success(bundle_file: pathlib.Path) -> None: mitre_svc.load_bundle(bundle_file) assert mitre_svc.mitre_loaded is True assert len(mitre_svc._index) == 6 # 7 attack-patterns minus 1 revoked = 6 def test_load_bundle_missing_file() -> None: mitre_svc.load_bundle(pathlib.Path("/nonexistent/path.json")) assert mitre_svc.mitre_loaded is False def test_load_bundle_invalid_json(tmp_path: pathlib.Path) -> None: bad = tmp_path / "bad.json" bad.write_text("{ not json }", encoding="utf-8") mitre_svc.load_bundle(bad) assert mitre_svc.mitre_loaded is False def test_load_bundle_excludes_revoked(bundle_file: pathlib.Path) -> None: mitre_svc.load_bundle(bundle_file) ids = [e["id"] for e in mitre_svc._index] assert "T9999" not in ids def test_load_bundle_includes_subtechniques(bundle_file: pathlib.Path) -> None: mitre_svc.load_bundle(bundle_file) ids = [e["id"] for e in mitre_svc._index] assert "T1059.001" in ids def test_load_bundle_extracts_tactics(bundle_file: pathlib.Path) -> None: mitre_svc.load_bundle(bundle_file) t1078 = next(e for e in mitre_svc._index if e["id"] == "T1078") assert "initial-access" in t1078["tactics"] assert "persistence" in t1078["tactics"] # --------------------------------------------------------------------------- # Unit tests for search # --------------------------------------------------------------------------- def test_search_exact_id_first(bundle_file: pathlib.Path) -> None: mitre_svc.load_bundle(bundle_file) results = mitre_svc.search("T1059") assert results[0]["id"] == "T1059" def test_search_prefix_id(bundle_file: pathlib.Path) -> None: mitre_svc.load_bundle(bundle_file) results = mitre_svc.search("T105") ids = [r["id"] for r in results] assert "T1059" in ids assert "T1059.001" in ids def test_search_name_substring(bundle_file: pathlib.Path) -> None: mitre_svc.load_bundle(bundle_file) results = mitre_svc.search("phish") assert any(r["id"] == "T1566" for r in results) def test_search_case_insensitive(bundle_file: pathlib.Path) -> None: mitre_svc.load_bundle(bundle_file) results = mitre_svc.search("POWERSHELL") assert any(r["id"] == "T1059.001" for r in results) def test_search_limit(bundle_file: pathlib.Path) -> None: mitre_svc.load_bundle(bundle_file) results = mitre_svc.search("T", limit=2) assert len(results) <= 2 def test_search_empty_query(bundle_file: pathlib.Path) -> None: mitre_svc.load_bundle(bundle_file) assert mitre_svc.search("") == [] def test_search_ranking_order(bundle_file: pathlib.Path) -> None: """exact-id > prefix-id > name match.""" mitre_svc.load_bundle(bundle_file) results = mitre_svc.search("T1059") # T1059 must come before T1059.001 (prefix match) ids = [r["id"] for r in results] assert ids.index("T1059") < ids.index("T1059.001") # --------------------------------------------------------------------------- # Endpoint tests # --------------------------------------------------------------------------- def test_mitre_endpoint_503_when_not_loaded( client: FlaskClient, redteam_token: str ) -> None: mitre_svc.mitre_loaded = False mitre_svc._index = [] resp = client.get("/api/mitre/techniques?q=T1059", headers=_h(redteam_token)) assert resp.status_code == 503 assert resp.get_json()["error"] == "mitre bundle not loaded" def test_mitre_endpoint_returns_results( client: FlaskClient, redteam_token: str, bundle_file: pathlib.Path ) -> None: mitre_svc.load_bundle(bundle_file) resp = client.get("/api/mitre/techniques?q=T1059", headers=_h(redteam_token)) assert resp.status_code == 200 data = resp.get_json() assert isinstance(data, list) assert any(r["id"] == "T1059" for r in data) def test_mitre_endpoint_requires_auth(client: FlaskClient) -> None: resp = client.get("/api/mitre/techniques?q=T1059") assert resp.status_code == 401 def test_mitre_endpoint_all_roles_can_access( client: FlaskClient, redteam_token: str, soc_token: str, admin_token: str, bundle_file: pathlib.Path, ) -> None: mitre_svc.load_bundle(bundle_file) for token in (redteam_token, soc_token, admin_token): resp = client.get("/api/mitre/techniques?q=T1059", headers=_h(token)) assert resp.status_code == 200 def test_mitre_endpoint_max_20_results( client: FlaskClient, redteam_token: str, bundle_file: pathlib.Path ) -> None: mitre_svc.load_bundle(bundle_file) resp = client.get("/api/mitre/techniques?q=T", headers=_h(redteam_token)) assert resp.status_code == 200 assert len(resp.get_json()) <= 20 def test_mitre_endpoint_includes_tactics( client: FlaskClient, redteam_token: str, bundle_file: pathlib.Path ) -> None: mitre_svc.load_bundle(bundle_file) resp = client.get("/api/mitre/techniques?q=T1566", headers=_h(redteam_token)) assert resp.status_code == 200 data = resp.get_json() assert len(data) >= 1 phishing = next((r for r in data if r["id"] == "T1566"), None) assert phishing is not None assert "initial-access" in phishing["tactics"] # --------------------------------------------------------------------------- # Sprint 3: get_tactics, lookup_name, get_matrix # --------------------------------------------------------------------------- def test_get_tactics_known(bundle_file: pathlib.Path) -> None: mitre_svc.load_bundle(bundle_file) tactics = mitre_svc.get_tactics("T1078") assert "initial-access" in tactics assert "persistence" in tactics def test_get_tactics_unknown_returns_empty(bundle_file: pathlib.Path) -> None: mitre_svc.load_bundle(bundle_file) assert mitre_svc.get_tactics("T0000") == [] def test_lookup_name_known(bundle_file: pathlib.Path) -> None: mitre_svc.load_bundle(bundle_file) assert mitre_svc.lookup_name("T1059") == "Command and Scripting Interpreter" def test_lookup_name_subtechnique(bundle_file: pathlib.Path) -> None: mitre_svc.load_bundle(bundle_file) assert mitre_svc.lookup_name("T1059.001") == "PowerShell" def test_lookup_name_unknown_returns_none(bundle_file: pathlib.Path) -> None: mitre_svc.load_bundle(bundle_file) assert mitre_svc.lookup_name("T0000") is None def test_get_matrix_returns_ordered_tactics(bundle_file: pathlib.Path) -> None: mitre_svc.load_bundle(bundle_file) matrix = mitre_svc.get_matrix() tactic_ids = [t["tactic_id"] for t in matrix] # TA0001 (initial-access) must come before TA0002 (execution) in canonical order. assert tactic_ids.index("TA0001") < tactic_ids.index("TA0002") def test_get_matrix_subtechniques_nested(bundle_file: pathlib.Path) -> None: mitre_svc.load_bundle(bundle_file) matrix = mitre_svc.get_matrix() exec_tactic = next(t for t in matrix if t["tactic_id"] == "TA0002") t1059 = next((t for t in exec_tactic["techniques"] if t["id"] == "T1059"), None) assert t1059 is not None sub_ids = [s["id"] for s in t1059["subtechniques"]] assert "T1059.001" in sub_ids assert "T1059.006" in sub_ids def test_get_matrix_subtechniques_sorted_by_name(bundle_file: pathlib.Path) -> None: mitre_svc.load_bundle(bundle_file) matrix = mitre_svc.get_matrix() exec_tactic = next(t for t in matrix if t["tactic_id"] == "TA0002") t1059 = next(t for t in exec_tactic["techniques"] if t["id"] == "T1059") names = [s["name"] for s in t1059["subtechniques"]] assert names == sorted(names) def test_get_matrix_techniques_sorted_by_name(bundle_file: pathlib.Path) -> None: mitre_svc.load_bundle(bundle_file) matrix = mitre_svc.get_matrix() ia_tactic = next(t for t in matrix if t["tactic_id"] == "TA0001") names = [t["name"] for t in ia_tactic["techniques"]] assert names == sorted(names) def test_get_matrix_technique_no_subtechniques(bundle_file: pathlib.Path) -> None: mitre_svc.load_bundle(bundle_file) matrix = mitre_svc.get_matrix() ia_tactic = next(t for t in matrix if t["tactic_id"] == "TA0001") phishing = next((t for t in ia_tactic["techniques"] if t["id"] == "T1566"), None) assert phishing is not None assert phishing["subtechniques"] == [] def test_matrix_endpoint_ok( client: FlaskClient, redteam_token: str, bundle_file: pathlib.Path ) -> None: mitre_svc.load_bundle(bundle_file) resp = client.get("/api/mitre/matrix", headers=_h(redteam_token)) assert resp.status_code == 200 data = resp.get_json() assert isinstance(data, list) tactic_ids = [t["tactic_id"] for t in data] assert "TA0001" in tactic_ids # initial-access assert "TA0002" in tactic_ids # execution def test_matrix_endpoint_503_when_not_loaded( client: FlaskClient, redteam_token: str ) -> None: mitre_svc.mitre_loaded = False resp = client.get("/api/mitre/matrix", headers=_h(redteam_token)) assert resp.status_code == 503 def test_matrix_endpoint_requires_auth(client: FlaskClient) -> None: resp = client.get("/api/mitre/matrix") assert resp.status_code == 401 def test_matrix_endpoint_all_roles( client: FlaskClient, redteam_token: str, soc_token: str, admin_token: str, bundle_file: pathlib.Path, ) -> None: mitre_svc.load_bundle(bundle_file) for token in (redteam_token, soc_token, admin_token): resp = client.get("/api/mitre/matrix", headers=_h(token)) assert resp.status_code == 200 def test_get_matrix_command_and_control_display_name(bundle_file: pathlib.Path) -> None: """MITRE official name uses lowercase 'and' — not title-cased.""" mitre_svc.load_bundle(bundle_file) matrix = mitre_svc.get_matrix() c2 = next((t for t in matrix if t["tactic_id"] == "TA0011"), None) assert c2 is not None assert c2["tactic_name"] == "Command and Control" def test_get_matrix_tactic_id_is_ta_format(bundle_file: pathlib.Path) -> None: """Matrix tactic_id must use TA-format so frontend can send it back in PATCH tactic_ids.""" mitre_svc.load_bundle(bundle_file) matrix = mitre_svc.get_matrix() for entry in matrix: tid = entry["tactic_id"] assert tid.startswith("TA"), f"tactic_id {tid!r} must be TA-format, not a slug"