"""MITRE ATT&CK reference endpoints. Read access is open to any authenticated user (the catalogue is reference data — not sensitive on its own). Sync is admin-only via `mitre.sync`. """ from __future__ import annotations import logging from flask import Blueprint, jsonify, request from pydantic import BaseModel from sqlalchemy import func, or_, select from app.core.auth_decorators import require_auth, require_perm from app.db.session import session_scope from app.models.mitre import MitreSubtechnique, MitreTactic, MitreTechnique from app.services import mitre_seed as mitre_seed_svc class SyncResultOut(BaseModel): """Response schema for `POST /mitre/sync`. Mirrors `SeedResult.as_dict()`.""" tactics_upserted: int techniques_upserted: int subtechniques_upserted: int subtechniques_skipped_orphan: int technique_tactic_links: int version: str | None source: str started_at: str finished_at: str duration_ms: int bp = Blueprint("mitre", __name__, url_prefix="/mitre") log = logging.getLogger("metamorph.api.mitre") def _pagination_args() -> tuple[int, int] | tuple[None, tuple[int, str]]: """Returns (limit, offset) or (None, (status, error_payload)).""" try: limit = int(request.args.get("limit", "100")) offset = int(request.args.get("offset", "0")) except ValueError: return None, (400, "invalid_pagination") limit = max(1, min(limit, 500)) offset = max(0, offset) return limit, offset def _search(stmt, model, q: str | None): if not q: return stmt like = f"%{q.lower()}%" return stmt.where( or_( func.lower(model.name).like(like), func.lower(model.external_id).like(like), ) ) def _serialize_tactic(t: MitreTactic) -> dict: return { "id": str(t.id), "external_id": t.external_id, "short_name": t.short_name, "name": t.name, "description": t.description, "url": t.url, } def _serialize_technique(t: MitreTechnique, *, include_tactics: bool = True) -> dict: out = { "id": str(t.id), "external_id": t.external_id, "name": t.name, "description": t.description, "url": t.url, } if include_tactics: out["tactics"] = sorted( ({"external_id": tac.external_id, "name": tac.name} for tac in t.tactics), key=lambda d: d["external_id"], ) return out def _serialize_subtechnique(sb: MitreSubtechnique) -> dict: return { "id": str(sb.id), "external_id": sb.external_id, "name": sb.name, "description": sb.description, "url": sb.url, "technique_id": str(sb.technique_id), } @bp.get("/tactics") @require_auth def list_tactics(): paging = _pagination_args() if paging[0] is None: return jsonify({"error": paging[1][1]}), paging[1][0] limit, offset = paging q = request.args.get("q") or None with session_scope() as s: stmt = select(MitreTactic).order_by(MitreTactic.external_id.asc()) stmt = _search(stmt, MitreTactic, q) total = s.scalar(_search(select(func.count()).select_from(MitreTactic), MitreTactic, q)) or 0 rows = s.scalars(stmt.limit(limit).offset(offset)).all() return jsonify( { "items": [_serialize_tactic(t) for t in rows], "total": int(total), "limit": limit, "offset": offset, } ) @bp.get("/techniques") @require_auth def list_techniques(): paging = _pagination_args() if paging[0] is None: return jsonify({"error": paging[1][1]}), paging[1][0] limit, offset = paging q = request.args.get("q") or None tactic = request.args.get("tactic") or None with session_scope() as s: stmt = select(MitreTechnique).order_by(MitreTechnique.external_id.asc()) count_stmt = select(func.count()).select_from(MitreTechnique) if tactic: tac = s.scalar(select(MitreTactic).where(MitreTactic.external_id == tactic)) if tac is None: return jsonify({"items": [], "total": 0, "limit": limit, "offset": offset}) stmt = stmt.join(MitreTechnique.tactics).where(MitreTactic.id == tac.id) count_stmt = ( count_stmt.select_from(MitreTechnique) .join(MitreTechnique.tactics) .where(MitreTactic.id == tac.id) ) stmt = _search(stmt, MitreTechnique, q) count_stmt = _search(count_stmt, MitreTechnique, q) total = s.scalar(count_stmt) or 0 rows = s.scalars(stmt.limit(limit).offset(offset)).all() return jsonify( { "items": [_serialize_technique(t) for t in rows], "total": int(total), "limit": limit, "offset": offset, } ) @bp.get("/subtechniques") @require_auth def list_subtechniques(): paging = _pagination_args() if paging[0] is None: return jsonify({"error": paging[1][1]}), paging[1][0] limit, offset = paging q = request.args.get("q") or None technique = request.args.get("technique") or None with session_scope() as s: stmt = select(MitreSubtechnique).order_by(MitreSubtechnique.external_id.asc()) count_stmt = select(func.count()).select_from(MitreSubtechnique) if technique: tech = s.scalar( select(MitreTechnique).where(MitreTechnique.external_id == technique) ) if tech is None: return jsonify({"items": [], "total": 0, "limit": limit, "offset": offset}) stmt = stmt.where(MitreSubtechnique.technique_id == tech.id) count_stmt = count_stmt.where(MitreSubtechnique.technique_id == tech.id) stmt = _search(stmt, MitreSubtechnique, q) count_stmt = _search(count_stmt, MitreSubtechnique, q) total = s.scalar(count_stmt) or 0 rows = s.scalars(stmt.limit(limit).offset(offset)).all() return jsonify( { "items": [_serialize_subtechnique(sb) for sb in rows], "total": int(total), "limit": limit, "offset": offset, } ) @bp.get("/matrix") @require_auth def matrix(): """Return the full Enterprise matrix: tactics → techniques → sub-techniques. One-shot endpoint so the SPA can render the flat attack.mitre.org-style grid without firing 15 parallel queries. The payload is ~55 KB serialised against MITRE v19 (15 tactics × ~50 techniques × ~3 subs). """ with session_scope() as s: # All techniques + their tactics (selectin-loaded by the relationship). techniques = s.scalars( select(MitreTechnique).order_by(MitreTechnique.external_id.asc()) ).all() # Sub-techniques bucketed by parent. subs_by_parent: dict = {} for sb in s.scalars( select(MitreSubtechnique).order_by(MitreSubtechnique.external_id.asc()) ).all(): subs_by_parent.setdefault(sb.technique_id, []).append( { "id": str(sb.id), "external_id": sb.external_id, "name": sb.name, } ) # Tactics in canonical kill-chain order (matches attack.mitre.org). tactics = s.scalars( select(MitreTactic).order_by(MitreTactic.external_id.asc()) ).all() # Group techniques by tactic short_name. techs_by_tactic: dict = {} for t in techniques: entry = { "id": str(t.id), "external_id": t.external_id, "name": t.name, "subtechniques": subs_by_parent.get(t.id, []), } for tac in t.tactics: techs_by_tactic.setdefault(tac.short_name, []).append(entry) return jsonify( { "tactics": [ { "id": str(t.id), "external_id": t.external_id, "short_name": t.short_name, "name": t.name, "techniques": techs_by_tactic.get(t.short_name, []), } for t in tactics ] } ) @bp.get("/status") @require_auth def status(): return jsonify(mitre_seed_svc.read_status()) @bp.post("/sync") @require_auth @require_perm("mitre.sync") def sync(): """Re-pull the configured (or default) STIX source and upsert. Custom `source` URLs MUST be paired with either `expected_sha256` (integrity guarantee) or `allow_unverified: true` (explicit opt-out) — the seed service will raise otherwise. The host is allowlisted (defaults to raw.githubusercontent.com, overridable via the MITRE_ALLOWED_HOSTS env). """ payload = request.get_json(silent=True) or {} source = payload.get("source") # optional URL override expected_sha256 = payload.get("expected_sha256") allow_unverified = bool(payload.get("allow_unverified", False)) try: result = mitre_seed_svc.seed_mitre( source=source, expected_sha256=expected_sha256 or (mitre_seed_svc.MITRE_DEFAULT_SHA256 if source is None else None), allow_unverified=allow_unverified, ) except mitre_seed_svc.MitreSourceForbidden as e: return jsonify({"error": "source_forbidden", "message": str(e)}), 400 except mitre_seed_svc.MitreChecksumMismatch as e: return jsonify({"error": "checksum_mismatch", "message": str(e)}), 502 except mitre_seed_svc.MitreSeedError as e: return jsonify({"error": "seed_failed", "message": str(e)}), 502 except Exception: # noqa: BLE001 # Do NOT leak the internal error string to the client (URLError stack, # DB driver text). The stack lands in our JSON logs. log.exception("metamorph.api.mitre.sync_failed") return jsonify({"error": "internal_error"}), 500 # Validate via the Pydantic Out model so the response contract is # explicit (single source of truth shared with the TS interface). payload_out = SyncResultOut.model_validate(result.as_dict()).model_dump() log.info("metamorph.api.mitre.sync_done", extra=payload_out) return jsonify(payload_out)