Files
Metamorph/backend/app/api/mitre.py
Knacky 63b48addc0 fix(m4): code-review pass — SSRF allowlist + advisory lock + typed contract
Six post-code-review fixes, applied before opening the PR per project
workflow (spec-review + code-review both gate the merge):

1. SSRF allowlist on `/mitre/sync`. Host must be in MITRE_ALLOWED_HOSTS
   (defaults to `raw.githubusercontent.com`, env-overridable). Closes the
   hole where an admin holding `mitre.sync` could, via a typo'd URL, point
   the api container at 169.254.169.254 or internal mirrors. New
   `MitreSourceForbidden` → 400 `source_forbidden`; checked at the top of
   `_download()` so it kicks in before any I/O.

2. `pg_advisory_xact_lock(hashtext('mitre.seed'))` at the top of the seed
   transaction. Two concurrent `/mitre/sync` requests now serialise across
   the DELETE+INSERT of `mitre_technique_tactics`; previously they could
   both wipe the M2M and one would fail the unique constraint on re-insert.

3. Typed SyncResult contract. Pydantic `SyncResultOut` on the Flask side
   `model_validate`s the dict before returning — single source of truth
   for the response shape, mirrored by a `MitreSyncResult` TS interface
   (next commit). The `as Record<string, unknown>` + `as { duration_ms }`
   cast in MitrePage is gone.

4. N+1 in dotted sub-technique fallback removed. Built
   `{external_id → technique_id}` once at function entry. Currently a
   no-op against MITRE official (0 orphans), but a latent footgun for
   partial / older bundles.

5. `SETTING_VERSION` cleared explicitly when `source != MITRE_DEFAULT_URL`.
   Previously it kept the stale pin label, so `/mitre/status` lied after
   a custom-URL re-sync.

6. `/mitre/sync` 500s no longer echo `str(e)` to the client — URLError /
   psycopg / Pydantic text now lives in the JSON log only. Public response
   stays `{"error": "internal_error"}`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 19:19:11 +02:00

297 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""MITRE ATT&CK reference endpoints.
Read access is open to any authenticated user (the catalogue is reference
data — not sensitive on its own). Sync is admin-only via `mitre.sync`.
"""
from __future__ import annotations
import logging
from flask import Blueprint, jsonify, request
from pydantic import BaseModel
from sqlalchemy import func, or_, select
from app.core.auth_decorators import require_auth, require_perm
from app.db.session import session_scope
from app.models.mitre import MitreSubtechnique, MitreTactic, MitreTechnique
from app.services import mitre_seed as mitre_seed_svc
class SyncResultOut(BaseModel):
    """Response schema for `POST /mitre/sync`. Mirrors `SeedResult.as_dict()`.

    The `sync` endpoint validates the seed result through this model and
    returns the dumped dict to the client; the same dict is attached to the
    `sync_done` log record. Field order matters: `model_dump()` emits keys in
    declaration order.
    """

    # Per-entity upsert counts reported by the seed run.
    tactics_upserted: int
    techniques_upserted: int
    subtechniques_upserted: int
    # Sub-techniques the seed could not attach to a parent technique.
    subtechniques_skipped_orphan: int
    # Rows written to the technique<->tactic link table.
    technique_tactic_links: int
    # Nullable; NOTE(review): presumably None for non-default sources — confirm against SeedResult.
    version: str | None
    # The STIX source the run pulled from (presumably the resolved URL — confirm).
    source: str
    # Run timestamps as strings; exact format is set by SeedResult.as_dict() (presumably ISO-8601 — confirm).
    started_at: str
    finished_at: str
    # Wall-clock duration of the run in milliseconds.
    duration_ms: int
# Module logger; the sync endpoint reports both successful runs and failures
# under this name.
log = logging.getLogger("metamorph.api.mitre")

# Every route defined in this module is mounted under /mitre.
bp = Blueprint(name="mitre", import_name=__name__, url_prefix="/mitre")
def _pagination_args() -> tuple[int, int] | tuple[None, tuple[int, str]]:
    """Read `limit` / `offset` from the current request's query string.

    Defaults are limit=100 and offset=0. On success returns `(limit, offset)`,
    with `limit` clamped to [1, 500] and `offset` floored at 0. If either value
    is not an integer, returns `(None, (400, "invalid_pagination"))`.
    """
    raw_limit = request.args.get("limit", "100")
    raw_offset = request.args.get("offset", "0")
    try:
        parsed_limit, parsed_offset = int(raw_limit), int(raw_offset)
    except ValueError:
        return None, (400, "invalid_pagination")
    return min(max(parsed_limit, 1), 500), max(parsed_offset, 0)
# Escape character used in LIKE patterns built from user input. "/" is used
# rather than a backslash to avoid dialect differences in how a backslash
# inside a string literal is treated; SQLAlchemy's own `autoescape` option
# uses "/" for the same reason.
_LIKE_ESCAPE = "/"


def _escape_like(term: str) -> str:
    """Return *term* with LIKE metacharacters escaped so it matches literally.

    The escape character itself is escaped first, so the escapes added for
    `%` and `_` are not doubled.
    """
    return (
        term.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )


def _search(stmt, model, q: str | None):
    """Narrow *stmt* to rows whose `name` or `external_id` contains *q*.

    The match is case-insensitive: both the columns and *q* are lower-cased.
    *q* is treated as literal text. `%` and `_` typed by the user are escaped
    rather than acting as LIKE wildcards, so a search for `t1059_001` does not
    also match `t1059x001`. Returns *stmt* unchanged when *q* is empty or None.
    """
    if not q:
        return stmt
    pattern = f"%{_escape_like(q.lower())}%"
    return stmt.where(
        or_(
            func.lower(model.name).like(pattern, escape=_LIKE_ESCAPE),
            func.lower(model.external_id).like(pattern, escape=_LIKE_ESCAPE),
        )
    )
def _serialize_tactic(t: MitreTactic) -> dict:
    """Render a tactic as a JSON-ready dict, with the primary key stringified.

    Keys are emitted in this order: id, external_id, short_name, name,
    description, url.
    """
    payload = {"id": str(t.id)}
    for attr in ("external_id", "short_name", "name", "description", "url"):
        payload[attr] = getattr(t, attr)
    return payload
def _serialize_technique(t: MitreTechnique, *, include_tactics: bool = True) -> dict:
    """Render a technique as a JSON-ready dict.

    With `include_tactics` (the default), a `tactics` list of
    `{"external_id", "name"}` refs is added, sorted by tactic external_id.
    """
    body = dict(
        id=str(t.id),
        external_id=t.external_id,
        name=t.name,
        description=t.description,
        url=t.url,
    )
    if not include_tactics:
        return body
    tactic_refs = [{"external_id": tac.external_id, "name": tac.name} for tac in t.tactics]
    tactic_refs.sort(key=lambda ref: ref["external_id"])
    body["tactics"] = tactic_refs
    return body
def _serialize_subtechnique(sb: MitreSubtechnique) -> dict:
    """Render a sub-technique as a JSON-ready dict.

    Both its own id and its parent `technique_id` are stringified.
    """
    fields = {"id": str(sb.id)}
    fields.update((key, getattr(sb, key)) for key in ("external_id", "name", "description", "url"))
    fields["technique_id"] = str(sb.technique_id)
    return fields
@bp.get("/tactics")
@require_auth
def list_tactics():
    """Return a paginated page of tactics ordered by external_id.

    Query params: `limit` / `offset` (see `_pagination_args`) and an optional
    `q` filter applied through `_search`. The response carries `items`,
    `total` (filtered count), `limit` and `offset`.
    """
    limit, offset = _pagination_args()
    if limit is None:
        # On failure the second slot carries (http_status, error_code).
        http_status, error_code = offset
        return jsonify({"error": error_code}), http_status
    q = request.args.get("q") or None
    with session_scope() as s:
        counting = _search(select(func.count()).select_from(MitreTactic), MitreTactic, q)
        total = s.scalar(counting) or 0
        ordered = _search(
            select(MitreTactic).order_by(MitreTactic.external_id.asc()), MitreTactic, q
        )
        page = s.scalars(ordered.limit(limit).offset(offset)).all()
        return jsonify(
            {
                "items": [_serialize_tactic(tac) for tac in page],
                "total": int(total),
                "limit": limit,
                "offset": offset,
            }
        )
@bp.get("/techniques")
@require_auth
def list_techniques():
    """Return a paginated page of techniques ordered by external_id.

    Query params:
      limit / offset -- pagination (see `_pagination_args`).
      q              -- optional filter applied through `_search`.
      tactic         -- optional tactic external_id. An unknown tactic yields
                        an empty page rather than an error.
    Each item carries its sorted tactic refs (see `_serialize_technique`).
    """
    limit, offset = _pagination_args()
    if limit is None:
        # On failure the second slot carries (http_status, error_code).
        http_status, error_code = offset
        return jsonify({"error": error_code}), http_status
    q = request.args.get("q") or None
    tactic = request.args.get("tactic") or None
    with session_scope() as s:
        stmt = select(MitreTechnique).order_by(MitreTechnique.external_id.asc())
        # count_stmt already selects FROM mitre techniques; the join below
        # extends that FROM, so no second select_from() is needed.
        count_stmt = select(func.count()).select_from(MitreTechnique)
        if tactic:
            tac = s.scalar(select(MitreTactic).where(MitreTactic.external_id == tactic))
            if tac is None:
                return jsonify({"items": [], "total": 0, "limit": limit, "offset": offset})
            # Apply the identical join + filter to both statements so `total`
            # counts exactly the rows the page is drawn from.
            stmt = stmt.join(MitreTechnique.tactics).where(MitreTactic.id == tac.id)
            count_stmt = count_stmt.join(MitreTechnique.tactics).where(MitreTactic.id == tac.id)
        stmt = _search(stmt, MitreTechnique, q)
        count_stmt = _search(count_stmt, MitreTechnique, q)
        total = s.scalar(count_stmt) or 0
        rows = s.scalars(stmt.limit(limit).offset(offset)).all()
        # Serialize while the session is still open: `_serialize_technique`
        # walks the `tactics` relationship of each row.
        return jsonify(
            {
                "items": [_serialize_technique(t) for t in rows],
                "total": int(total),
                "limit": limit,
                "offset": offset,
            }
        )
@bp.get("/subtechniques")
@require_auth
def list_subtechniques():
    """Return a paginated page of sub-techniques ordered by external_id.

    Query params: `limit` / `offset` (see `_pagination_args`), an optional
    `q` filter applied through `_search`, and `technique` (a parent technique
    external_id). An unknown parent yields an empty page rather than an error.
    """
    limit, offset = _pagination_args()
    if limit is None:
        # On failure the second slot carries (http_status, error_code).
        http_status, error_code = offset
        return jsonify({"error": error_code}), http_status
    q = request.args.get("q") or None
    technique = request.args.get("technique") or None
    with session_scope() as s:
        listing = select(MitreSubtechnique).order_by(MitreSubtechnique.external_id.asc())
        counting = select(func.count()).select_from(MitreSubtechnique)
        if technique:
            parent = s.scalar(
                select(MitreTechnique).where(MitreTechnique.external_id == technique)
            )
            if parent is None:
                return jsonify({"items": [], "total": 0, "limit": limit, "offset": offset})
            same_parent = MitreSubtechnique.technique_id == parent.id
            listing = listing.where(same_parent)
            counting = counting.where(same_parent)
        listing = _search(listing, MitreSubtechnique, q)
        counting = _search(counting, MitreSubtechnique, q)
        total = s.scalar(counting) or 0
        page = s.scalars(listing.limit(limit).offset(offset)).all()
        return jsonify(
            {
                "items": [_serialize_subtechnique(child) for child in page],
                "total": int(total),
                "limit": limit,
                "offset": offset,
            }
        )
# Left-to-right column order of the Enterprise matrix on attack.mitre.org,
# keyed by tactic short_name. Tactic external_ids do NOT follow this order:
# Reconnaissance is TA0043, Resource Development is TA0042, and Command and
# Control (TA0011) comes before Exfiltration (TA0010). Tactics missing from
# this list (e.g. one introduced by a newer bundle) sort after the known
# ones, by external_id.
_KILL_CHAIN_ORDER: tuple[str, ...] = (
    "reconnaissance",
    "resource-development",
    "initial-access",
    "execution",
    "persistence",
    "privilege-escalation",
    "defense-evasion",
    "credential-access",
    "discovery",
    "lateral-movement",
    "collection",
    "command-and-control",
    "exfiltration",
    "impact",
)
_KILL_CHAIN_RANK: dict[str, int] = {name: rank for rank, name in enumerate(_KILL_CHAIN_ORDER)}


def _kill_chain_key(tactic: MitreTactic) -> tuple[int, str]:
    """Sort key that places *tactic* in kill-chain order.

    Tactics with an unknown short_name get a rank past the end of the known
    list, then are ordered by external_id.
    """
    rank = _KILL_CHAIN_RANK.get(tactic.short_name, len(_KILL_CHAIN_ORDER))
    return rank, tactic.external_id


@bp.get("/matrix")
@require_auth
def matrix():
    """Return the full Enterprise matrix: tactics → techniques → sub-techniques.

    One-shot endpoint so the SPA can render the flat attack.mitre.org-style
    grid without firing 15 parallel queries. The payload is ~55 KB serialised
    against MITRE v19 (15 tactics × ~50 techniques × ~3 subs).

    Tactics are emitted in kill-chain order (see `_KILL_CHAIN_ORDER`).
    Techniques within a tactic and sub-techniques within a technique are
    ordered by external_id. A technique linked to several tactics appears
    under each of them.
    """
    with session_scope() as s:
        # All techniques + their tactics (selectin-loaded by the relationship).
        techniques = s.scalars(
            select(MitreTechnique).order_by(MitreTechnique.external_id.asc())
        ).all()
        # Sub-techniques bucketed by parent.
        subs_by_parent: dict = {}
        for sb in s.scalars(
            select(MitreSubtechnique).order_by(MitreSubtechnique.external_id.asc())
        ).all():
            subs_by_parent.setdefault(sb.technique_id, []).append(
                {
                    "id": str(sb.id),
                    "external_id": sb.external_id,
                    "name": sb.name,
                }
            )
        # external_id order from SQL is only a deterministic starting point;
        # the stable sort then applies the kill-chain column order.
        tactics = sorted(
            s.scalars(select(MitreTactic).order_by(MitreTactic.external_id.asc())).all(),
            key=_kill_chain_key,
        )
        # Group techniques by tactic short_name.
        techs_by_tactic: dict = {}
        for t in techniques:
            entry = {
                "id": str(t.id),
                "external_id": t.external_id,
                "name": t.name,
                "subtechniques": subs_by_parent.get(t.id, []),
            }
            for tac in t.tactics:
                techs_by_tactic.setdefault(tac.short_name, []).append(entry)
        # Build the response while the session is still open, so no ORM
        # attribute is read from a detached instance.
        return jsonify(
            {
                "tactics": [
                    {
                        "id": str(t.id),
                        "external_id": t.external_id,
                        "short_name": t.short_name,
                        "name": t.name,
                        "techniques": techs_by_tactic.get(t.short_name, []),
                    }
                    for t in tactics
                ]
            }
        )
@bp.get("/status")
@require_auth
def status():
    """Return whatever `mitre_seed.read_status()` reports, as JSON."""
    current = mitre_seed_svc.read_status()
    return jsonify(current)
@bp.post("/sync")
@require_auth
@require_perm("mitre.sync")
def sync():
    """Re-pull the configured (or default) STIX source and upsert.

    Custom `source` URLs MUST be paired with either `expected_sha256` (integrity
    guarantee) or `allow_unverified: true` (explicit opt-out) — the seed service
    will raise otherwise. The host is allowlisted (defaults to
    raw.githubusercontent.com, overridable via the MITRE_ALLOWED_HOSTS env).

    JSON body (every field optional):
      source           -- string URL override. Omitted/null means the default
                          source, which is pinned to `MITRE_DEFAULT_SHA256`
                          unless `expected_sha256` is supplied.
      expected_sha256  -- string digest the download must match.
      allow_unverified -- must be a real JSON boolean (null is treated as
                          false), so a string such as "false" can never
                          silently opt out of verification.

    Responses:
      200 -- `SyncResultOut` payload.
      400 -- `invalid_payload` (non-object body or wrongly typed field) or
             `source_forbidden`.
      502 -- `checksum_mismatch` or `seed_failed`.
      500 -- `internal_error`; details go to the log only.
    """
    payload = request.get_json(silent=True) or {}
    if not isinstance(payload, dict):
        return jsonify(
            {"error": "invalid_payload", "message": "request body must be a JSON object"}
        ), 400
    source = payload.get("source")  # optional URL override
    expected_sha256 = payload.get("expected_sha256")
    allow_unverified = payload.get("allow_unverified", False)
    for field_name, value in (("source", source), ("expected_sha256", expected_sha256)):
        if value is not None and not isinstance(value, str):
            return jsonify(
                {"error": "invalid_payload", "message": f"`{field_name}` must be a string"}
            ), 400
    if allow_unverified is None:
        allow_unverified = False
    if not isinstance(allow_unverified, bool):
        # Previously bool() was applied, so "false" / "no" became True.
        return jsonify(
            {"error": "invalid_payload", "message": "`allow_unverified` must be a boolean"}
        ), 400
    try:
        result = mitre_seed_svc.seed_mitre(
            source=source,
            expected_sha256=expected_sha256
            or (mitre_seed_svc.MITRE_DEFAULT_SHA256 if source is None else None),
            allow_unverified=allow_unverified,
        )
        # Validate via the Pydantic Out model so the response contract is
        # explicit (single source of truth shared with the TS interface). It
        # sits inside the try so a contract drift is logged and answered with
        # `internal_error` instead of escaping the handler.
        payload_out = SyncResultOut.model_validate(result.as_dict()).model_dump()
    except mitre_seed_svc.MitreSourceForbidden as e:
        return jsonify({"error": "source_forbidden", "message": str(e)}), 400
    except mitre_seed_svc.MitreChecksumMismatch as e:
        return jsonify({"error": "checksum_mismatch", "message": str(e)}), 502
    except mitre_seed_svc.MitreSeedError as e:
        return jsonify({"error": "seed_failed", "message": str(e)}), 502
    except Exception:  # noqa: BLE001
        # Do NOT leak the internal error string to the client (URLError stack,
        # DB driver text, Pydantic validation text). The stack lands in our
        # JSON logs.
        log.exception("metamorph.api.mitre.sync_failed")
        return jsonify({"error": "internal_error"}), 500
    log.info("metamorph.api.mitre.sync_done", extra=payload_out)
    return jsonify(payload_out)