feature/m4-mitre #1

Merged
knacky merged 13 commits from feature/m4-mitre into main 2026-05-12 17:24:14 +00:00
2 changed files with 82 additions and 21 deletions
Showing only changes of commit 63b48addc0 - Show all commits

View File

@@ -9,6 +9,7 @@ from __future__ import annotations
import logging import logging
from flask import Blueprint, jsonify, request from flask import Blueprint, jsonify, request
from pydantic import BaseModel
from sqlalchemy import func, or_, select from sqlalchemy import func, or_, select
from app.core.auth_decorators import require_auth, require_perm from app.core.auth_decorators import require_auth, require_perm
@@ -16,6 +17,21 @@ from app.db.session import session_scope
from app.models.mitre import MitreSubtechnique, MitreTactic, MitreTechnique from app.models.mitre import MitreSubtechnique, MitreTactic, MitreTechnique
from app.services import mitre_seed as mitre_seed_svc from app.services import mitre_seed as mitre_seed_svc
class SyncResultOut(BaseModel):
"""Response schema for `POST /mitre/sync`. Mirrors `SeedResult.as_dict()`."""
tactics_upserted: int
techniques_upserted: int
subtechniques_upserted: int
subtechniques_skipped_orphan: int
technique_tactic_links: int
version: str | None
source: str
started_at: str
finished_at: str
duration_ms: int
bp = Blueprint("mitre", __name__, url_prefix="/mitre") bp = Blueprint("mitre", __name__, url_prefix="/mitre")
log = logging.getLogger("metamorph.api.mitre") log = logging.getLogger("metamorph.api.mitre")
@@ -248,7 +264,8 @@ def sync():
Custom `source` URLs MUST be paired with either `expected_sha256` (integrity Custom `source` URLs MUST be paired with either `expected_sha256` (integrity
guarantee) or `allow_unverified: true` (explicit opt-out) — the seed service guarantee) or `allow_unverified: true` (explicit opt-out) — the seed service
will raise otherwise. will raise otherwise. The host is allowlisted (defaults to
raw.githubusercontent.com, overridable via the MITRE_ALLOWED_HOSTS env).
""" """
payload = request.get_json(silent=True) or {} payload = request.get_json(silent=True) or {}
source = payload.get("source") # optional URL override source = payload.get("source") # optional URL override
@@ -261,12 +278,19 @@ def sync():
or (mitre_seed_svc.MITRE_DEFAULT_SHA256 if source is None else None), or (mitre_seed_svc.MITRE_DEFAULT_SHA256 if source is None else None),
allow_unverified=allow_unverified, allow_unverified=allow_unverified,
) )
except mitre_seed_svc.MitreSourceForbidden as e:
return jsonify({"error": "source_forbidden", "message": str(e)}), 400
except mitre_seed_svc.MitreChecksumMismatch as e: except mitre_seed_svc.MitreChecksumMismatch as e:
return jsonify({"error": "checksum_mismatch", "message": str(e)}), 502 return jsonify({"error": "checksum_mismatch", "message": str(e)}), 502
except mitre_seed_svc.MitreSeedError as e: except mitre_seed_svc.MitreSeedError as e:
return jsonify({"error": "seed_failed", "message": str(e)}), 502 return jsonify({"error": "seed_failed", "message": str(e)}), 502
except Exception as e: # noqa: BLE001 except Exception: # noqa: BLE001
# Do NOT leak the internal error string to the client (URLError stack,
# DB driver text). The stack lands in our JSON logs.
log.exception("metamorph.api.mitre.sync_failed") log.exception("metamorph.api.mitre.sync_failed")
return jsonify({"error": "internal_error", "message": str(e)}), 500 return jsonify({"error": "internal_error"}), 500
log.warning("metamorph.api.mitre.sync_done", extra=result.as_dict()) # Validate via the Pydantic Out model so the response contract is
return jsonify(result.as_dict()) # explicit (single source of truth shared with the TS interface).
payload_out = SyncResultOut.model_validate(result.as_dict()).model_dump()
log.info("metamorph.api.mitre.sync_done", extra=payload_out)
return jsonify(payload_out)

View File

@@ -29,7 +29,7 @@ from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Iterable from typing import Iterable
from sqlalchemy import delete, select from sqlalchemy import delete, select, text as sql_text
from app.db.session import session_scope from app.db.session import session_scope
from app.models.mitre import ( from app.models.mitre import (
@@ -59,6 +59,18 @@ MITRE_DEFAULT_SHA256 = "df520ea0775a57db7bff760145b02fed89290802913e056b7ed5970b
MITRE_BUNDLE_CACHE_PATH = Path(os.environ.get("MITRE_CACHE_DIR", "/data/mitre")) MITRE_BUNDLE_CACHE_PATH = Path(os.environ.get("MITRE_CACHE_DIR", "/data/mitre"))
MITRE_DOWNLOAD_TIMEOUT_SECONDS = 120 MITRE_DOWNLOAD_TIMEOUT_SECONDS = 120
# Hosts authorised as a source for a MITRE sync. An admin holding `mitre.sync`
# could otherwise pivot the in-container HTTP fetch to internal services
# (169.254.169.254, db, internal mirrors). Override via the `MITRE_ALLOWED_HOSTS`
# env (comma-separated) when running against a private mirror.
MITRE_ALLOWED_HOSTS: frozenset[str] = frozenset(
h.strip()
for h in os.environ.get(
"MITRE_ALLOWED_HOSTS", "raw.githubusercontent.com"
).split(",")
if h.strip()
)
# Settings keys used to expose the seed metadata to the operator UI/CLI. # Settings keys used to expose the seed metadata to the operator UI/CLI.
SETTING_LAST_SYNC = "mitre_last_sync" SETTING_LAST_SYNC = "mitre_last_sync"
SETTING_VERSION = "mitre_version" SETTING_VERSION = "mitre_version"
@@ -76,6 +88,10 @@ class MitreChecksumMismatch(MitreSeedError):
pass pass
class MitreSourceForbidden(MitreSeedError):
"""The provided source URL points to a host outside the allowlist."""
@dataclass @dataclass
class ParsedBundle: class ParsedBundle:
tactics: list[dict] = field(default_factory=list) tactics: list[dict] = field(default_factory=list)
@@ -123,6 +139,18 @@ def _is_url(source: str) -> bool:
return parsed.scheme in ("http", "https") return parsed.scheme in ("http", "https")
def _ensure_host_allowed(url: str) -> None:
"""Raise MitreSourceForbidden if the URL targets a non-allowlisted host."""
parsed = urllib.parse.urlparse(url)
if parsed.scheme not in ("http", "https"):
raise MitreSourceForbidden(f"unsupported URL scheme: {parsed.scheme!r}")
host = (parsed.hostname or "").lower()
if host not in MITRE_ALLOWED_HOSTS:
raise MitreSourceForbidden(
f"host {host!r} not in MITRE_ALLOWED_HOSTS={sorted(MITRE_ALLOWED_HOSTS)}"
)
def _sha256_of(path: Path) -> str: def _sha256_of(path: Path) -> str:
h = hashlib.sha256() h = hashlib.sha256()
with path.open("rb") as f: with path.open("rb") as f:
@@ -132,6 +160,7 @@ def _sha256_of(path: Path) -> str:
def _download(url: str, dest: Path, *, expected_sha256: str | None = None) -> Path: def _download(url: str, dest: Path, *, expected_sha256: str | None = None) -> Path:
_ensure_host_allowed(url)
dest.parent.mkdir(parents=True, exist_ok=True) dest.parent.mkdir(parents=True, exist_ok=True)
tmp = dest.with_suffix(dest.suffix + ".part") tmp = dest.with_suffix(dest.suffix + ".part")
log.info("metamorph.mitre.download.start", extra={"url": url, "dest": str(dest)}) log.info("metamorph.mitre.download.start", extra={"url": url, "dest": str(dest)})
@@ -331,8 +360,18 @@ def _upsert_subtechniques(
subtechniques: Iterable[dict], subtechniques: Iterable[dict],
stix_to_tech_id: dict, stix_to_tech_id: dict,
) -> tuple[int, int]: ) -> tuple[int, int]:
"""Returns (n_upserted, n_skipped_orphans).""" """Returns (n_upserted, n_skipped_orphans).
`n_upserted` is the count of rows whose state was applied (INSERT or
UPDATE) — matches Postgres upsert semantics.
"""
existing = {sb.external_id: sb for sb in s.scalars(select(MitreSubtechnique)).all()} existing = {sb.external_id: sb for sb in s.scalars(select(MitreSubtechnique)).all()}
# Pre-index techniques by external_id so the dotted-id fallback doesn't
# issue N+1 SELECTs (was a latent footgun for partial-bundle re-syncs).
parent_by_external: dict[str, object] = {
t.external_id: t.id
for t in s.scalars(select(MitreTechnique)).all()
}
n_upserted = 0 n_upserted = 0
n_skipped = 0 n_skipped = 0
for sb in subtechniques: for sb in subtechniques:
@@ -342,17 +381,7 @@ def _upsert_subtechniques(
# Fall back to the dotted external_id convention (T1003.001 → T1003). # Fall back to the dotted external_id convention (T1003.001 → T1003).
m = re.match(r"^(T\d+)\.\d+$", sb["external_id"]) m = re.match(r"^(T\d+)\.\d+$", sb["external_id"])
if m: if m:
parent_ext = m.group(1) parent_id = parent_by_external.get(m.group(1))
# We don't have a parent-by-external-id map here; query.
parent_row = next(
iter(
s.scalars(
select(MitreTechnique).where(MitreTechnique.external_id == parent_ext)
).all()
),
None,
)
parent_id = parent_row.id if parent_row else None
if parent_id is None: if parent_id is None:
log.warning( log.warning(
"metamorph.mitre.orphan_subtechnique", "metamorph.mitre.orphan_subtechnique",
@@ -433,6 +462,13 @@ def seed_mitre(
) )
with session_scope() as s: with session_scope() as s:
# Serialize concurrent /mitre/sync calls. The lock is transaction-scoped
# (released automatically at COMMIT/ROLLBACK), so a second sync arriving
# while the first is mid-DELETE+INSERT of `mitre_technique_tactics`
# blocks until the first commits. Avoids the unique-constraint race the
# code-reviewer flagged. hashtext() is stable across sessions.
s.execute(sql_text("SELECT pg_advisory_xact_lock(hashtext('mitre.seed'))"))
short_to_tactic_id, n_tactics = _upsert_tactics(s, parsed.tactics) short_to_tactic_id, n_tactics = _upsert_tactics(s, parsed.tactics)
stix_to_tech_id, n_techs, n_links = _upsert_techniques( stix_to_tech_id, n_techs, n_links = _upsert_techniques(
s, parsed.techniques, short_to_tactic_id s, parsed.techniques, short_to_tactic_id
@@ -441,10 +477,11 @@ def seed_mitre(
finished_at = datetime.now(tz=timezone.utc) finished_at = datetime.now(tz=timezone.utc)
_upsert_setting(s, SETTING_LAST_SYNC, finished_at.isoformat()) _upsert_setting(s, SETTING_LAST_SYNC, finished_at.isoformat())
# If the URL is the pinned one, we know the version; otherwise leave None. # `version` reflects the known pin only when seeded from MITRE_DEFAULT_URL;
# otherwise we explicitly clear it so /mitre/status doesn't lie about a
# stale version after a custom-URL re-sync.
version = MITRE_VERSION if source_label == MITRE_DEFAULT_URL else None version = MITRE_VERSION if source_label == MITRE_DEFAULT_URL else None
if version: _upsert_setting(s, SETTING_VERSION, version)
_upsert_setting(s, SETTING_VERSION, version)
_upsert_setting(s, SETTING_SOURCE_URL, source_label) _upsert_setting(s, SETTING_SOURCE_URL, source_label)
result = SeedResult( result = SeedResult(