"""MITRE ATT&CK reference endpoints. Read access is open to any authenticated user (the catalogue is reference data — not sensitive on its own). Sync is admin-only via `mitre.sync`. """ from __future__ import annotations import logging from flask import Blueprint, jsonify, request from sqlalchemy import func, or_, select from app.core.auth_decorators import require_auth, require_perm from app.db.session import session_scope from app.models.mitre import MitreSubtechnique, MitreTactic, MitreTechnique from app.services import mitre_seed as mitre_seed_svc bp = Blueprint("mitre", __name__, url_prefix="/mitre") log = logging.getLogger("metamorph.api.mitre") def _pagination_args() -> tuple[int, int] | tuple[None, tuple[int, str]]: """Returns (limit, offset) or (None, (status, error_payload)).""" try: limit = int(request.args.get("limit", "100")) offset = int(request.args.get("offset", "0")) except ValueError: return None, (400, "invalid_pagination") limit = max(1, min(limit, 500)) offset = max(0, offset) return limit, offset def _search(stmt, model, q: str | None): if not q: return stmt like = f"%{q.lower()}%" return stmt.where( or_( func.lower(model.name).like(like), func.lower(model.external_id).like(like), ) ) def _serialize_tactic(t: MitreTactic) -> dict: return { "id": str(t.id), "external_id": t.external_id, "short_name": t.short_name, "name": t.name, "description": t.description, "url": t.url, } def _serialize_technique(t: MitreTechnique, *, include_tactics: bool = True) -> dict: out = { "id": str(t.id), "external_id": t.external_id, "name": t.name, "description": t.description, "url": t.url, } if include_tactics: out["tactics"] = sorted( ({"external_id": tac.external_id, "name": tac.name} for tac in t.tactics), key=lambda d: d["external_id"], ) return out def _serialize_subtechnique(sb: MitreSubtechnique) -> dict: return { "id": str(sb.id), "external_id": sb.external_id, "name": sb.name, "description": sb.description, "url": sb.url, "technique_id": str(sb.technique_id), } @bp.get("/tactics") @require_auth def list_tactics(): paging = _pagination_args() if paging[0] is None: return jsonify({"error": paging[1][1]}), paging[1][0] limit, offset = paging q = request.args.get("q") or None with session_scope() as s: stmt = select(MitreTactic).order_by(MitreTactic.external_id.asc()) stmt = _search(stmt, MitreTactic, q) total = s.scalar(_search(select(func.count()).select_from(MitreTactic), MitreTactic, q)) or 0 rows = s.scalars(stmt.limit(limit).offset(offset)).all() return jsonify( { "items": [_serialize_tactic(t) for t in rows], "total": int(total), "limit": limit, "offset": offset, } ) @bp.get("/techniques") @require_auth def list_techniques(): paging = _pagination_args() if paging[0] is None: return jsonify({"error": paging[1][1]}), paging[1][0] limit, offset = paging q = request.args.get("q") or None tactic = request.args.get("tactic") or None with session_scope() as s: stmt = select(MitreTechnique).order_by(MitreTechnique.external_id.asc()) count_stmt = select(func.count()).select_from(MitreTechnique) if tactic: tac = s.scalar(select(MitreTactic).where(MitreTactic.external_id == tactic)) if tac is None: return jsonify({"items": [], "total": 0, "limit": limit, "offset": offset}) stmt = stmt.join(MitreTechnique.tactics).where(MitreTactic.id == tac.id) count_stmt = ( count_stmt.select_from(MitreTechnique) .join(MitreTechnique.tactics) .where(MitreTactic.id == tac.id) ) stmt = _search(stmt, MitreTechnique, q) count_stmt = _search(count_stmt, MitreTechnique, q) total = s.scalar(count_stmt) or 0 rows = s.scalars(stmt.limit(limit).offset(offset)).all() return jsonify( { "items": [_serialize_technique(t) for t in rows], "total": int(total), "limit": limit, "offset": offset, } ) @bp.get("/subtechniques") @require_auth def list_subtechniques(): paging = _pagination_args() if paging[0] is None: return jsonify({"error": paging[1][1]}), paging[1][0] limit, offset = paging q = request.args.get("q") or None technique = request.args.get("technique") or None with session_scope() as s: stmt = select(MitreSubtechnique).order_by(MitreSubtechnique.external_id.asc()) count_stmt = select(func.count()).select_from(MitreSubtechnique) if technique: tech = s.scalar( select(MitreTechnique).where(MitreTechnique.external_id == technique) ) if tech is None: return jsonify({"items": [], "total": 0, "limit": limit, "offset": offset}) stmt = stmt.where(MitreSubtechnique.technique_id == tech.id) count_stmt = count_stmt.where(MitreSubtechnique.technique_id == tech.id) stmt = _search(stmt, MitreSubtechnique, q) count_stmt = _search(count_stmt, MitreSubtechnique, q) total = s.scalar(count_stmt) or 0 rows = s.scalars(stmt.limit(limit).offset(offset)).all() return jsonify( { "items": [_serialize_subtechnique(sb) for sb in rows], "total": int(total), "limit": limit, "offset": offset, } ) @bp.get("/status") @require_auth def status(): return jsonify(mitre_seed_svc.read_status()) @bp.post("/sync") @require_auth @require_perm("mitre.sync") def sync(): """Re-pull the configured (or default) STIX source and upsert. Custom `source` URLs MUST be paired with either `expected_sha256` (integrity guarantee) or `allow_unverified: true` (explicit opt-out) — the seed service will raise otherwise. """ payload = request.get_json(silent=True) or {} source = payload.get("source") # optional URL override expected_sha256 = payload.get("expected_sha256") allow_unverified = bool(payload.get("allow_unverified", False)) try: result = mitre_seed_svc.seed_mitre( source=source, expected_sha256=expected_sha256 or (mitre_seed_svc.MITRE_DEFAULT_SHA256 if source is None else None), allow_unverified=allow_unverified, ) except mitre_seed_svc.MitreChecksumMismatch as e: return jsonify({"error": "checksum_mismatch", "message": str(e)}), 502 except mitre_seed_svc.MitreSeedError as e: return jsonify({"error": "seed_failed", "message": str(e)}), 502 except Exception as e: # noqa: BLE001 log.exception("metamorph.api.mitre.sync_failed") return jsonify({"error": "internal_error", "message": str(e)}), 500 log.warning("metamorph.api.mitre.sync_done", extra=result.as_dict()) return jsonify(result.as_dict())