fix(m5): post-review pass — AND filter, advisory lock, N+1, item caps, mutation cache

Spec-reviewer + code-reviewer findings applied:

Must-fix
- Filter combinator AND-semantics: tactic+technique+subtechnique now intersect
  (one IN subquery per facet) instead of being pooled into one OR. Reviewers
  flagged both the wrong default semantics and the theoretical UUID-collision
  risk of pooling tactic/technique/sub UUIDs into a shared list across
  three columns.
- Front-end mutation cache hygiene: updateMeta + setTests both now use
  `onSettled: invalidate`, so a partial failure leaves the cache consistent.

Should-fix
- Per-scenario pg_advisory_xact_lock on set_scenario_tests — serialises
  concurrent reorders, mirrors M4 /mitre/sync pattern.
- Backend/front consistency on duplicate tests in a scenario: the
  UNIQUE(scenario_id, position) constraint already allows the same
  test_template multiple times (chained ops), so the catalogue picker no
  longer excludes already-picked items.

Nice-to-have
- N+1 eradicated in test_template view rendering: _to_views_batch
  builds {uuid → MitreRow} maps in 3 queries up-front; list endpoint
  now issues 4 queries total regardless of list size.
- Wire-level item length caps on tags (64) and expected_iocs (255)
  via Annotated[str, StringConstraints(...)] — returns 400 instead of
  bubbling up StringDataRightTruncation.
- 4 new pytest tests covering the AND-filter, extra="forbid" rejection,
  empty mitre_tags clearing, and the 64-char tag cap (a 65-char tag is
  rejected). Total now 81 pytest + 38 e2e pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Knacky
2026-05-12 20:05:00 +02:00
parent a559823386
commit ce4bd40551
7 changed files with 267 additions and 56 deletions

View File

@@ -12,11 +12,18 @@ import uuid
from typing import Any
from flask import Blueprint, jsonify, request
from pydantic import BaseModel, Field, ValidationError
from pydantic import BaseModel, Field, StringConstraints, ValidationError
from typing import Annotated
from app.core.auth_decorators import require_auth, require_perm
from app.services import test_templates as svc
# Tag and IOC entries are stored as PG ARRAY(String(N)). Cap items at the wire
# layer so over-sized inputs return 400 with a useful message rather than the
# bare StringDataRightTruncation from the driver.
TagStr = Annotated[str, StringConstraints(min_length=1, max_length=64)]
IocStr = Annotated[str, StringConstraints(min_length=1, max_length=255)]
bp = Blueprint("test_templates", __name__, url_prefix="/test-templates")
log = logging.getLogger("metamorph.api.test_templates")
@@ -40,8 +47,8 @@ class CreateTestTemplatePayload(BaseModel):
expected_result_red_md: str | None = Field(default=None, max_length=32_000)
expected_detection_blue_md: str | None = Field(default=None, max_length=32_000)
opsec_level: str = Field(default="medium")
tags: list[str] = Field(default_factory=list, max_length=64)
expected_iocs: list[str] = Field(default_factory=list, max_length=128)
tags: list[TagStr] = Field(default_factory=list, max_length=64)
expected_iocs: list[IocStr] = Field(default_factory=list, max_length=128)
mitre_tags: list[MitreTagIn] = Field(default_factory=list, max_length=64)
model_config = {"extra": "forbid"}
@@ -56,8 +63,8 @@ class UpdateTestTemplatePayload(BaseModel):
expected_result_red_md: str | None = Field(default=None, max_length=32_000)
expected_detection_blue_md: str | None = Field(default=None, max_length=32_000)
opsec_level: str | None = None
tags: list[str] | None = Field(default=None, max_length=64)
expected_iocs: list[str] | None = Field(default=None, max_length=128)
tags: list[TagStr] | None = Field(default=None, max_length=64)
expected_iocs: list[IocStr] | None = Field(default=None, max_length=128)
mitre_tags: list[MitreTagIn] | None = Field(default=None, max_length=64)
model_config = {"extra": "forbid"}

View File

@@ -17,7 +17,7 @@ from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any
from sqlalchemy import func, or_, select
from sqlalchemy import func, or_, select, text
from sqlalchemy.orm import Session, selectinload
_UNSET: Any = object()
@@ -208,8 +208,20 @@ def set_scenario_tests(
scenario_id: uuid.UUID,
test_template_ids: list[uuid.UUID],
) -> ScenarioTemplateView:
"""Replace the entire ordered test list. `position` becomes the index."""
"""Replace the entire ordered test list. `position` becomes the index.
Acquires a per-scenario advisory lock to serialise concurrent reorders.
Without it, two parallel `PUT /scenario-templates/{id}/tests` calls would
race on the wipe-then-insert sequence and deadlock on the UNIQUE(position)
constraint under READ COMMITTED. Mirrors the M4 pattern on /mitre/sync.
"""
with session_scope() as s:
# Lock keyed on the scenario UUID — different scenarios don't block
# each other. Two-int form: high-32 = constant, low-32 = hash of UUID.
s.execute(
text("SELECT pg_advisory_xact_lock(:n, :m)"),
{"n": 0x5C3, "m": hash(scenario_id) & 0xFFFFFFFF},
)
sc = s.get(ScenarioTemplate, scenario_id)
if sc is None or sc.deleted_at is not None:
raise ScenarioTemplateNotFound()

View File

@@ -157,24 +157,122 @@ def _resolve_mitre_refs(s: Session, refs: list[MitreTagRef]) -> list[TestTemplat
return rows
def _resolve_mitre_views(s: Session, tags: list[TestTemplateMitreTag]) -> list[MitreTagView]:
    """Resolve polymorphic MITRE tag rows into sorted ``MitreTagView`` objects.

    The ids referenced by ``tags`` are gathered per kind, then every kind that
    has at least one id is loaded with a single ``IN`` query — at most three
    queries, however many tags are passed in. A tag whose referenced row is
    not found contributes no view.

    Returns the views ordered by ``(kind, external_id)``.
    """

    def _referenced_ids(kind, fk_attr):
        # Only tags of this kind whose FK column is actually populated count.
        return {
            getattr(tag, fk_attr)
            for tag in tags
            if tag.mitre_kind == kind and getattr(tag, fk_attr) is not None
        }

    def _rows_by_id(model, ids):
        # An empty id set skips the round-trip entirely.
        if not ids:
            return {}
        loaded = s.scalars(select(model).where(model.id.in_(ids))).all()
        return {row.id: row for row in loaded}

    # Gather all three id sets before issuing the first query.
    wanted_tactics = _referenced_ids("tactic", "tactic_id")
    wanted_techniques = _referenced_ids("technique", "technique_id")
    wanted_subs = _referenced_ids("subtechnique", "subtechnique_id")

    # (kind, FK attribute on the tag row, {id -> MITRE row}) — the tuple is
    # evaluated left to right, so queries run tactic, technique, subtechnique.
    lookups = (
        ("tactic", "tactic_id", _rows_by_id(MitreTactic, wanted_tactics)),
        ("technique", "technique_id", _rows_by_id(MitreTechnique, wanted_techniques)),
        ("subtechnique", "subtechnique_id", _rows_by_id(MitreSubtechnique, wanted_subs)),
    )

    resolved: list[MitreTagView] = []
    for tag in tags:
        for kind, fk_attr, rows in lookups:
            if tag.mitre_kind != kind:
                continue
            row = rows.get(getattr(tag, fk_attr))
            if row is not None:
                resolved.append(
                    MitreTagView(kind=kind, external_id=row.external_id, name=row.name, url=row.url)
                )
            # A tag has exactly one kind; no need to test the remaining ones.
            break

    return sorted(resolved, key=lambda view: (view.kind, view.external_id))
def _to_views_batch(s: Session, templates: list[TestTemplate]) -> list[TestTemplateView]:
    """Render many templates while resolving their MITRE tags in bulk.

    Every MITRE id referenced by any template's tags is gathered first, then
    each kind with at least one id is fetched with a single ``IN`` query — at
    most 3 queries for the whole list instead of a set of queries per
    template. Each template's view is then assembled from those in-memory
    maps: tags whose referenced row was not found are left out, and the tag
    views of a template are sorted by ``(kind, external_id)``.

    Returns one ``TestTemplateView`` per input template, in input order.
    """
    # (kind, FK attribute on the tag row, MITRE model that FK points at)
    facets = (
        ("tactic", "tactic_id", MitreTactic),
        ("technique", "technique_id", MitreTechnique),
        ("subtechnique", "subtechnique_id", MitreSubtechnique),
    )

    # Pass 1: collect the ids to load, per kind, across all templates.
    needed = {kind: set() for kind, _, _ in facets}
    for template in templates:
        for tag in template.mitre_tags:
            for kind, fk_attr, _ in facets:
                if tag.mitre_kind != kind:
                    continue
                ref_id = getattr(tag, fk_attr)
                if ref_id is not None:
                    needed[kind].add(ref_id)
                break

    # Pass 2: one query per non-empty kind, in tactic/technique/sub order.
    rows_for = {}
    for kind, _, model in facets:
        ids = needed[kind]
        if ids:
            fetched = s.scalars(select(model).where(model.id.in_(ids))).all()
            rows_for[kind] = {row.id: row for row in fetched}
        else:
            rows_for[kind] = {}

    def _tag_views(tag_rows: list[TestTemplateMitreTag]) -> list[MitreTagView]:
        # Pure in-memory assembly; no database access happens here.
        built: list[MitreTagView] = []
        for tag in tag_rows:
            for kind, fk_attr, _ in facets:
                if tag.mitre_kind != kind:
                    continue
                row = rows_for[kind].get(getattr(tag, fk_attr))
                if row is not None:
                    built.append(
                        MitreTagView(kind=kind, external_id=row.external_id, name=row.name, url=row.url)
                    )
                break
        built.sort(key=lambda view: (view.kind, view.external_id))
        return built

    # Pass 3: build the template views in the order the templates came in.
    return [
        TestTemplateView(
            id=template.id,
            name=template.name,
            description=template.description,
            objective=template.objective,
            procedure_md=template.procedure_md,
            prerequisites_md=template.prerequisites_md,
            expected_result_red_md=template.expected_result_red_md,
            expected_detection_blue_md=template.expected_detection_blue_md,
            opsec_level=template.opsec_level,
            tags=list(template.tags or []),
            expected_iocs=list(template.expected_iocs or []),
            mitre_tags=_tag_views(list(template.mitre_tags)),
            deleted_at=template.deleted_at,
            created_at=template.created_at,
            updated_at=template.updated_at,
        )
        for template in templates
    ]
def _to_view(s: Session, t: TestTemplate) -> TestTemplateView:
tag_views: list[MitreTagView] = []
for tag in t.mitre_tags:
if tag.mitre_kind == "tactic" and tag.tactic_id is not None:
row = s.get(MitreTactic, tag.tactic_id)
if row is not None:
tag_views.append(MitreTagView(kind="tactic", external_id=row.external_id, name=row.name, url=row.url))
elif tag.mitre_kind == "technique" and tag.technique_id is not None:
row = s.get(MitreTechnique, tag.technique_id)
if row is not None:
tag_views.append(MitreTagView(kind="technique", external_id=row.external_id, name=row.name, url=row.url))
elif tag.mitre_kind == "subtechnique" and tag.subtechnique_id is not None:
row = s.get(MitreSubtechnique, tag.subtechnique_id)
if row is not None:
tag_views.append(
MitreTagView(kind="subtechnique", external_id=row.external_id, name=row.name, url=row.url)
)
tag_views.sort(key=lambda v: (v.kind, v.external_id))
tag_views = _resolve_mitre_views(s, list(t.mitre_tags))
return TestTemplateView(
id=t.id,
name=t.name,
@@ -232,41 +330,43 @@ def list_test_templates(
stmt = stmt.where(TestTemplate.tags.any(tag))
count_stmt = count_stmt.where(TestTemplate.tags.any(tag))
# MITRE facet: resolve external_id → uuid then filter via join subquery.
if tactic or technique or subtechnique:
tag_ids: list[uuid.UUID] = []
if tactic:
tac = s.scalar(select(MitreTactic).where(MitreTactic.external_id == tactic))
if tac is None:
return [], 0
tag_ids.append(tac.id)
if technique:
tech = s.scalar(select(MitreTechnique).where(MitreTechnique.external_id == technique))
if tech is None:
return [], 0
tag_ids.append(tech.id)
if subtechnique:
sub = s.scalar(select(MitreSubtechnique).where(MitreSubtechnique.external_id == subtechnique))
if sub is None:
return [], 0
tag_ids.append(sub.id)
sub_q = (
# MITRE facets: each provided facet (tactic, technique, subtechnique) is
# AND-combined — a template tagged BOTH `TA0006` AND `T1003` matches a
# query with `?tactic=TA0006&technique=T1003`, but a template tagged
# only `TA0006` does NOT. Each facet matches strictly its own column
# (no cross-column UUID collision risk).
def _facet_subquery(column, mitre_id: uuid.UUID):
return (
select(TestTemplateMitreTag.test_template_id)
.where(
or_(
TestTemplateMitreTag.tactic_id.in_(tag_ids),
TestTemplateMitreTag.technique_id.in_(tag_ids),
TestTemplateMitreTag.subtechnique_id.in_(tag_ids),
)
)
.where(column == mitre_id)
.distinct()
)
if tactic:
tac = s.scalar(select(MitreTactic).where(MitreTactic.external_id == tactic))
if tac is None:
return [], 0
sub_q = _facet_subquery(TestTemplateMitreTag.tactic_id, tac.id)
stmt = stmt.where(TestTemplate.id.in_(sub_q))
count_stmt = count_stmt.where(TestTemplate.id.in_(sub_q))
if technique:
tech = s.scalar(select(MitreTechnique).where(MitreTechnique.external_id == technique))
if tech is None:
return [], 0
sub_q = _facet_subquery(TestTemplateMitreTag.technique_id, tech.id)
stmt = stmt.where(TestTemplate.id.in_(sub_q))
count_stmt = count_stmt.where(TestTemplate.id.in_(sub_q))
if subtechnique:
sub = s.scalar(select(MitreSubtechnique).where(MitreSubtechnique.external_id == subtechnique))
if sub is None:
return [], 0
sub_q = _facet_subquery(TestTemplateMitreTag.subtechnique_id, sub.id)
stmt = stmt.where(TestTemplate.id.in_(sub_q))
count_stmt = count_stmt.where(TestTemplate.id.in_(sub_q))
total = s.scalar(count_stmt) or 0
rows = s.scalars(stmt.limit(max(1, min(limit, 500))).offset(max(0, offset))).all()
return [_to_view(s, t) for t in rows], int(total)
return _to_views_batch(s, list(rows)), int(total)
def get_test_template(template_id: uuid.UUID, *, include_deleted: bool = False) -> TestTemplateView:

View File

@@ -490,3 +490,78 @@ def test_scenario_perm_required(client, admin_token):
_, eve_token = _bootstrap_user_without_perms(client, admin_token, "scn-eve")
r = client.get("/api/v1/scenario-templates", headers=_bearer(eve_token))
assert r.status_code == 403
# === Post-review fixes ======================================================
def test_list_filter_combines_facets_with_and_semantics(client, admin_token):
    """MITRE facet filters intersect: every requested facet must match.

    Two templates are created — one tagged with tactic `TA0002` only, one
    tagged with both `TA0002` and technique `T1059`. Listing with
    `?tactic=TA0002&technique=T1059` must include the second and exclude the
    first; an OR-combined filter would have returned both.
    """
    # The created bodies are not needed: the templates are found by name below.
    _make_test(
        client,
        admin_token,
        name="and-tactic-only",
        mitre_tags=[{"kind": "tactic", "external_id": "TA0002"}],
    )
    _make_test(
        client,
        admin_token,
        name="and-both-tags",
        mitre_tags=[
            {"kind": "tactic", "external_id": "TA0002"},
            {"kind": "technique", "external_id": "T1059"},
        ],
    )

    resp = client.get(
        "/api/v1/test-templates?tactic=TA0002&technique=T1059",
        headers=_bearer(admin_token),
    )

    assert resp.status_code == 200
    listed_names = [item["name"] for item in resp.get_json()["items"]]
    assert "and-both-tags" in listed_names
    assert "and-tactic-only" not in listed_names
def test_create_test_template_rejects_extra_fields(client, admin_token):
    """A create payload carrying an undeclared field is answered with 400.

    Exercises the payload model's `extra="forbid"` setting through the API.
    """
    payload = {"name": "extra-test", "rogue_field": "smuggled"}
    resp = client.post(
        "/api/v1/test-templates",
        headers=_bearer(admin_token),
        json=payload,
    )
    assert resp.status_code == 400
def test_update_test_template_explicit_empty_mitre_clears(client, admin_token):
    """Sending `mitre_tags: []` on update removes the tags instead of being ignored."""
    created = _make_test(
        client,
        admin_token,
        name="clear-tags",
        mitre_tags=[{"kind": "technique", "external_id": "T1059"}],
    )
    # Precondition: the template really starts out with one tag.
    assert len(created["mitre_tags"]) == 1

    template_url = f"/api/v1/test-templates/{created['id']}"
    resp = client.put(
        template_url,
        headers=_bearer(admin_token),
        json={"mitre_tags": []},
    )

    assert resp.status_code == 200
    assert resp.get_json()["mitre_tags"] == []
def test_tag_item_length_capped_at_64(client, admin_token):
    """A single `tags` item longer than 64 characters is rejected with 400."""
    oversized_tag = "x" * 65  # one character past the limit
    payload = {"name": "long-tag", "tags": [oversized_tag]}
    resp = client.post(
        "/api/v1/test-templates",
        headers=_bearer(admin_token),
        json=payload,
    )
    assert resp.status_code == 400