"""MITRE ATT&CK bundle loader and search service."""

from __future__ import annotations

import json
import logging
from pathlib import Path
from typing import Any

logger = logging.getLogger(__name__)

_BUNDLE_PATH = Path(__file__).parent.parent.parent / "data" / "mitre" / "enterprise-attack.json"

# Tactic order used for the matrix columns (12 tactics).
# NOTE(review): current ATT&CK Enterprise also defines "reconnaissance"
# (TA0043) and "resource-development" (TA0042). They are not listed here, so
# techniques that belong only to those tactics never appear in the matrix —
# confirm this omission is intentional.
_TACTIC_ORDER = [
    "initial-access",
    "execution",
    "persistence",
    "privilege-escalation",
    "defense-evasion",
    "credential-access",
    "discovery",
    "lateral-movement",
    "collection",
    "command-and-control",
    "exfiltration",
    "impact",
]

# TA-id → short-name mapping (MITRE Enterprise, IDs are not sequential).
_TACTIC_IDS: dict[str, str] = {
    "TA0001": "initial-access",
    "TA0002": "execution",
    "TA0003": "persistence",
    "TA0004": "privilege-escalation",
    "TA0005": "defense-evasion",
    "TA0006": "credential-access",
    "TA0007": "discovery",
    "TA0008": "lateral-movement",
    "TA0009": "collection",
    "TA0011": "command-and-control",
    "TA0010": "exfiltration",
    "TA0040": "impact",
}

# Short-name → display-name mapping for every tactic in _TACTIC_ORDER.
TACTIC_NAMES: dict[str, str] = {
    "initial-access": "Initial Access",
    "execution": "Execution",
    "persistence": "Persistence",
    "privilege-escalation": "Privilege Escalation",
    "defense-evasion": "Defense Evasion",
    "credential-access": "Credential Access",
    "discovery": "Discovery",
    "lateral-movement": "Lateral Movement",
    "collection": "Collection",
    "command-and-control": "Command and Control",
    "exfiltration": "Exfiltration",
    "impact": "Impact",
}

# True only after the most recent load_bundle() call succeeded.
mitre_loaded: bool = False

# Flat list of {"id", "name", "tactics"} entries, in bundle order.
_index: list[dict[str, Any]] = []
_tactics_by_technique: dict[str, list[str]] = {}
_name_by_id: dict[str, str] = {}
# matrix: list of tactic dicts (built once at load time)
_matrix: list[dict[str, Any]] = []


def _extract_tactics(obj: dict[str, Any]) -> list[str]:
    """Return the object's ``mitre-attack`` kill-chain phase names.

    Order of first appearance is kept and duplicates are dropped, so a
    technique can never be listed twice under the same tactic. Malformed
    phase records (non-dict, missing or non-string ``phase_name``) are
    ignored.
    """
    phases = obj.get("kill_chain_phases")
    if not isinstance(phases, list):
        return []
    names = [
        phase["phase_name"]
        for phase in phases
        if isinstance(phase, dict)
        and phase.get("kill_chain_name") == "mitre-attack"
        and isinstance(phase.get("phase_name"), str)
    ]
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(names))


def _get_external_id(obj: dict[str, Any]) -> str | None:
    """Return the first usable ``mitre-attack`` external id, or None.

    A reference without a non-empty string ``external_id`` is skipped rather
    than ending the scan, so a later well-formed reference is still found.
    """
    refs = obj.get("external_references")
    if not isinstance(refs, list):
        return None
    for ref in refs:
        if not isinstance(ref, dict) or ref.get("source_name") != "mitre-attack":
            continue
        ext_id = ref.get("external_id")
        if isinstance(ext_id, str) and ext_id:
            return ext_id
    return None


def _is_subtechnique(tech_id: str) -> bool:
    """Return True when the id contains a dot (e.g. ``T1059.001``)."""
    return "." in tech_id


def _parent_id(sub_id: str) -> str:
    """Return the part of a sub-technique id before the first dot."""
    return sub_id.split(".")[0]


def _build_matrix(entries: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Build the tactic → techniques → subtechniques tree.

    Args:
        entries: dicts with ``id``, ``name`` and ``tactics`` keys.

    Returns:
        One dict per tactic in ``_TACTIC_ORDER``, each holding its top-level
        techniques sorted by name; every technique carries its
        sub-techniques (``{"id", "name"}``), also sorted by name.
        Tactics not present in ``_TACTIC_ORDER`` are ignored, and
        sub-techniques whose parent is not listed under a tactic do not
        appear.
    """
    techniques_by_tactic: dict[str, list[dict[str, Any]]] = {
        tactic: [] for tactic in _TACTIC_ORDER
    }
    subs_by_parent: dict[str, list[dict[str, Any]]] = {}

    for entry in entries:
        if _is_subtechnique(entry["id"]):
            subs_by_parent.setdefault(_parent_id(entry["id"]), []).append(
                {"id": entry["id"], "name": entry["name"]}
            )
            continue
        for tactic in entry["tactics"]:
            bucket = techniques_by_tactic.get(tactic)
            if bucket is not None:
                bucket.append(entry)

    for subs in subs_by_parent.values():
        subs.sort(key=lambda sub: sub["name"])

    return [
        {
            "tactic_id": tactic,
            "tactic_name": TACTIC_NAMES.get(tactic, tactic.replace("-", " ").title()),
            "techniques": [
                {
                    "id": tech["id"],
                    "name": tech["name"],
                    "subtechniques": subs_by_parent.get(tech["id"], []),
                }
                for tech in sorted(techniques_by_tactic[tactic], key=lambda t: t["name"])
            ],
        }
        for tactic in _TACTIC_ORDER
    ]


def _read_objects(bundle_path: Path) -> list[Any] | None:
    """Read the bundle file and return its ``objects`` array.

    Returns None (after logging a warning) when the file is missing,
    unreadable, not valid UTF-8 JSON, or not shaped like a bundle.
    """
    try:
        data = json.loads(bundle_path.read_text(encoding="utf-8"))
    except FileNotFoundError:
        logger.warning("MITRE bundle not found at %s — autocomplete disabled", bundle_path)
        return None
    except (ValueError, OSError) as exc:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        logger.warning("MITRE bundle parse error: %s — autocomplete disabled", exc)
        return None

    if not isinstance(data, dict):
        logger.warning(
            "MITRE bundle parse error: %s — autocomplete disabled",
            f"top-level JSON value is {type(data).__name__}, expected object",
        )
        return None

    objects = data.get("objects") or []
    if not isinstance(objects, list):
        logger.warning(
            "MITRE bundle parse error: %s — autocomplete disabled",
            f"'objects' is {type(objects).__name__}, expected array",
        )
        return None
    return objects


def load_bundle(path: Path | None = None) -> None:
    """Load the MITRE bundle into memory. Called once at app boot.

    Args:
        path: bundle location (``Path`` or ``str``); defaults to the bundled
            ``enterprise-attack.json``.

    Only ``attack-pattern`` objects that are neither revoked nor deprecated
    and that carry a ``mitre-attack`` external id are indexed. A missing or
    non-string ``name`` is stored as ``""`` so sorting and searching never
    fail on it.

    This function never raises for a missing or malformed bundle: it logs a
    warning and sets ``mitre_loaded`` to False, leaving any previously loaded
    data untouched.
    """
    global mitre_loaded, _index, _tactics_by_technique, _name_by_id, _matrix

    bundle_path = Path(path) if path is not None else _BUNDLE_PATH
    objects = _read_objects(bundle_path)
    if objects is None:
        mitre_loaded = False
        return

    entries: list[dict[str, Any]] = []
    tactics_map: dict[str, list[str]] = {}
    name_map: dict[str, str] = {}

    for obj in objects:
        if not isinstance(obj, dict) or obj.get("type") != "attack-pattern":
            continue
        if obj.get("revoked") or obj.get("x_mitre_deprecated"):
            continue
        ext_id = _get_external_id(obj)
        if not ext_id:
            continue

        name = obj.get("name")
        if not isinstance(name, str):
            name = ""
        tactics = _extract_tactics(obj)

        entries.append({"id": ext_id, "name": name, "tactics": tactics})
        tactics_map[ext_id] = tactics
        name_map[ext_id] = name

    # Build the matrix before publishing anything so the globals are swapped
    # together only once all derived data exists.
    matrix = _build_matrix(entries)

    _index = entries
    _tactics_by_technique = tactics_map
    _name_by_id = name_map
    _matrix = matrix
    mitre_loaded = True
    logger.info("MITRE bundle loaded: %d techniques", len(_index))


def get_tactics(technique_id: str) -> list[str]:
    """Return tactic list for a technique id; empty list if unknown."""
    return _tactics_by_technique.get(technique_id, [])


def lookup_name(technique_id: str) -> str | None:
    """Return the name for a technique id, or None if not in the bundle."""
    return _name_by_id.get(technique_id)


def get_matrix() -> list[dict[str, Any]]:
    """Return the full tactic → techniques → subtechniques tree."""
    return _matrix


def get_tactic_name(tactic_id: str) -> str | None:
    """Return the display name for a TA-id, or None if unknown."""
    short = _TACTIC_IDS.get(tactic_id)
    if short is None:
        return None
    return TACTIC_NAMES[short]


def lookup_tactic(tactic_id: str) -> dict[str, str] | None:
    """Return {id, name} for a TA-id, or None if unknown."""
    name = get_tactic_name(tactic_id)
    if name is None:
        return None
    return {"id": tactic_id, "name": name}


def search(query: str, limit: int = 20) -> list[dict[str, Any]]:
    """Return up to `limit` techniques matching `query`.

    Ranking: exact id > prefix id > substring name (case-insensitive).
    Within each rank, entries keep bundle order. A blank query or a
    non-positive `limit` yields an empty list.
    """
    needle = query.strip().upper()
    if not needle or limit <= 0:
        return []

    exact: list[dict[str, Any]] = []
    prefix: list[dict[str, Any]] = []
    name_match: list[dict[str, Any]] = []

    for entry in _index:
        tech_id = entry["id"].upper()
        if tech_id == needle:
            exact.append(entry)
        elif tech_id.startswith(needle):
            prefix.append(entry)
        elif needle in entry["name"].upper():
            name_match.append(entry)

    return (exact + prefix + name_match)[:limit]