"""MITRE ATT&CK bundle loader and search service.""" from __future__ import annotations import json import logging from pathlib import Path from typing import Any logger = logging.getLogger(__name__) # Absolute path to the committed bundle. _BUNDLE_PATH = Path(__file__).parent.parent.parent / "data" / "mitre" / "enterprise-attack.json" mitre_loaded: bool = False _index: list[dict[str, Any]] = [] def _extract_tactics(obj: dict[str, Any]) -> list[str]: phases = obj.get("kill_chain_phases") or [] return [ p["phase_name"] for p in phases if isinstance(p, dict) and "phase_name" in p ] def _get_external_id(obj: dict[str, Any]) -> str | None: for ref in obj.get("external_references") or []: if isinstance(ref, dict) and ref.get("source_name") == "mitre-attack": return ref.get("external_id") return None def load_bundle(path: Path | None = None) -> None: """Load the MITRE bundle into memory. Called once at app boot.""" global mitre_loaded, _index bundle_path = path or _BUNDLE_PATH try: raw = bundle_path.read_text(encoding="utf-8") data = json.loads(raw) except FileNotFoundError: logger.warning("MITRE bundle not found at %s — autocomplete disabled", bundle_path) mitre_loaded = False return except (json.JSONDecodeError, OSError) as exc: logger.warning("MITRE bundle parse error: %s — autocomplete disabled", exc) mitre_loaded = False return entries: list[dict[str, Any]] = [] for obj in data.get("objects") or []: if not isinstance(obj, dict): continue if obj.get("type") != "attack-pattern": continue if obj.get("revoked") or obj.get("x_mitre_deprecated"): continue ext_id = _get_external_id(obj) if not ext_id: continue entries.append( { "id": ext_id, "name": obj.get("name", ""), "tactics": _extract_tactics(obj), } ) _index = entries mitre_loaded = True logger.info("MITRE bundle loaded: %d techniques", len(_index)) def search(query: str, limit: int = 20) -> list[dict[str, Any]]: """Return up to `limit` techniques matching `query`. Ranking: exact id > prefix id > substring name (case-insensitive). """ q = query.strip().upper() if not q: return [] exact: list[dict[str, Any]] = [] prefix: list[dict[str, Any]] = [] name_match: list[dict[str, Any]] = [] for entry in _index: tech_id = entry["id"].upper() tech_name = entry["name"].upper() if tech_id == q: exact.append(entry) elif tech_id.startswith(q): prefix.append(entry) elif q in tech_name: name_match.append(entry) combined = exact + prefix + name_match return combined[:limit]