Editing a scenario and saving (with or without changes) returned a 500 error: "function pg_advisory_xact_lock(smallint, bigint) does not exist". Postgres only ships (int4, int4) and (bigint) variants. The two-arg call passed `m = hash(uuid) & 0xFFFFFFFF`, which can reach 2^32-1, so psycopg promoted it to bigint and no overload matched. Switched to the single-arg bigint form. While there, replaced Python's built-in hash() with hashlib.blake2b(...): the built-in is randomised per process via PYTHONHASHSEED, so gunicorn workers were computing different lock keys for the same scenario and the lock wasn't actually serialising across workers. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
260 lines
8.6 KiB
Python
260 lines
8.6 KiB
Python
"""CRUD service for `scenario_templates` + their ordered test list.

Re-ordering is implemented as **full delete + re-insert** of the
`scenario_template_tests` rows. The UNIQUE (scenario_template_id, position)
constraint makes any naive position-swap fail mid-transaction; wiping the set
then re-inserting at positions 0..N-1 keeps the operation atomic and obvious.

The same test_template may legitimately appear multiple times in a scenario
(chained operations), so we key on `(scenario_id, position)`, not
`(scenario_id, test_template_id)`.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import uuid
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timezone
|
|
from typing import Any
|
|
|
|
from sqlalchemy import func, or_, select, text
|
|
from sqlalchemy.orm import Session, selectinload
|
|
|
|
# Sentinel default meaning "argument omitted", distinct from an explicit None.
# update_scenario_template compares against it (identity check) to decide
# whether the description should be touched at all.
_UNSET: Any = object()
|
|
|
|
from app.db.session import session_scope
|
|
from app.models.template import (
|
|
ScenarioTemplate,
|
|
ScenarioTemplateTest,
|
|
TestTemplate,
|
|
)
|
|
|
|
|
|
class ScenarioTemplateNotFound(Exception):
    """Raised when the requested scenario template is absent or soft-deleted."""
|
|
|
|
|
|
class UnknownTestTemplate(Exception):
    """Signals that a scenario points at a test template that is missing
    or has been soft-deleted.
    """
|
|
|
|
|
|
@dataclass(frozen=True)
class ScenarioTestView:
    """Immutable read model for one entry of a scenario's ordered test list."""

    # Index of this entry within the scenario's list.
    position: int
    # Identifier of the referenced test template.
    test_template_id: uuid.UUID
    # Display name of the referenced test template.
    test_template_name: str
    # Whether the referenced test template is flagged as deleted.
    test_template_deleted: bool
|
|
|
|
|
|
@dataclass(frozen=True)
class ScenarioTemplateView:
    """Immutable read model of a scenario template and its ordered tests."""

    # Identity and descriptive fields.
    id: uuid.UUID
    name: str
    description: str | None
    # Ordered test entries, plus their count.
    tests: list[ScenarioTestView]
    tests_count: int
    # Soft-delete marker; None while the template is live.
    deleted_at: datetime | None
    # Row timestamps.
    created_at: datetime
    updated_at: datetime
|
|
|
|
|
|
def _to_view(s: Session, sc: ScenarioTemplate) -> ScenarioTemplateView:
    """Convert a ScenarioTemplate row into its immutable view.

    Test template names and deleted flags are looked up through ``s`` with a
    single IN query. A link whose template is not returned by that query is
    rendered as ``"<missing>"`` and flagged as deleted.
    """
    referenced_ids = [link.test_template_id for link in sc.tests]

    # id -> (name, is_deleted) for every template the query returns.
    resolved: dict[uuid.UUID, tuple[str, bool]] = {}
    if referenced_ids:
        query = select(TestTemplate).where(TestTemplate.id.in_(referenced_ids))
        resolved = {
            tpl.id: (tpl.name, tpl.deleted_at is not None)
            for tpl in s.scalars(query).all()
        }

    fallback = ("<missing>", True)
    entries: list[ScenarioTestView] = []
    for link in sc.tests:
        label, is_deleted = resolved.get(link.test_template_id, fallback)
        entries.append(
            ScenarioTestView(
                position=link.position,
                test_template_id=link.test_template_id,
                test_template_name=label,
                test_template_deleted=is_deleted,
            )
        )

    return ScenarioTemplateView(
        id=sc.id,
        name=sc.name,
        description=sc.description,
        tests=entries,
        tests_count=len(entries),
        deleted_at=sc.deleted_at,
        created_at=sc.created_at,
        updated_at=sc.updated_at,
    )
|
|
|
|
|
|
def _base_query():
    """Return a ScenarioTemplate SELECT with ``tests`` loaded via selectinload."""
    eager_tests = selectinload(ScenarioTemplate.tests)
    return select(ScenarioTemplate).options(eager_tests)
|
|
|
|
|
|
def list_scenario_templates(
    *,
    q: str | None = None,
    include_deleted: bool = False,
    limit: int = 100,
    offset: int = 0,
) -> tuple[list[ScenarioTemplateView], int]:
    """Return one page of scenario templates plus the total match count.

    ``q`` is an optional case-insensitive substring matched against the name
    and the description. It is matched literally: ``%`` and ``_`` typed by the
    user are escaped rather than treated as LIKE wildcards. Soft-deleted rows
    are excluded unless ``include_deleted`` is true. ``limit`` is clamped to
    1..500 and a negative ``offset`` is treated as 0.

    The returned total counts every row matching the filters, ignoring
    pagination.
    """
    with session_scope() as s:
        # Secondary sort on id keeps LIMIT/OFFSET paging deterministic when
        # several templates share the same name.
        stmt = _base_query().order_by(
            ScenarioTemplate.name.asc(),
            ScenarioTemplate.id.asc(),
        )
        count_stmt = select(func.count()).select_from(ScenarioTemplate)

        if not include_deleted:
            live_only = ScenarioTemplate.deleted_at.is_(None)
            stmt = stmt.where(live_only)
            count_stmt = count_stmt.where(live_only)

        if q:
            needle = q.lower()
            # contains(..., autoescape=True) renders LIKE '%' || :needle || '%'
            # with an ESCAPE clause, so wildcard characters in the search text
            # only match themselves.
            cond = or_(
                func.lower(ScenarioTemplate.name).contains(needle, autoescape=True),
                func.lower(ScenarioTemplate.description).contains(needle, autoescape=True),
            )
            stmt = stmt.where(cond)
            count_stmt = count_stmt.where(cond)

        total = s.scalar(count_stmt) or 0
        page_size = max(1, min(limit, 500))
        skip = max(0, offset)
        rows = s.scalars(stmt.limit(page_size).offset(skip)).all()
        return [_to_view(s, sc) for sc in rows], int(total)
|
|
|
|
|
|
def get_scenario_template(scenario_id: uuid.UUID, *, include_deleted: bool = False) -> ScenarioTemplateView:
    """Fetch one scenario template by primary key and return its view.

    Raises ScenarioTemplateNotFound when no row exists, or when the row is
    soft-deleted and ``include_deleted`` is false.
    """
    with session_scope() as s:
        found = s.get(ScenarioTemplate, scenario_id)
        hidden = (
            found is not None
            and found.deleted_at is not None
            and not include_deleted
        )
        if found is None or hidden:
            raise ScenarioTemplateNotFound()
        return _to_view(s, found)
|
|
|
|
|
|
def _validate_test_ids(s: Session, ids: list[uuid.UUID]) -> None:
    """Ensure every id in ``ids`` names an existing, live test template.

    Raises UnknownTestTemplate if any id has no matching row, or (checked
    second) if any matching row is soft-deleted. An empty ``ids`` is accepted
    without querying.
    """
    if not ids:
        return

    stmt = select(TestTemplate.id, TestTemplate.deleted_at).where(TestTemplate.id.in_(ids))

    # Single pass over the result: collect ids that exist and, separately,
    # those that exist but carry a deletion timestamp.
    existing: set[uuid.UUID] = set()
    soft_deleted: set[uuid.UUID] = set()
    for row in s.execute(stmt).all():
        existing.add(row.id)
        if row.deleted_at is not None:
            soft_deleted.add(row.id)

    absent = set(ids).difference(existing)
    if absent:
        raise UnknownTestTemplate(f"unknown test_template ids: {sorted(str(m) for m in absent)}")
    if soft_deleted:
        raise UnknownTestTemplate(
            f"cannot reference soft-deleted test_template ids: {sorted(str(d) for d in soft_deleted)}"
        )
|
|
|
|
|
|
def _opt_str(value: str | None) -> str | None:
|
|
if value is None:
|
|
return None
|
|
s = value.strip()
|
|
return s or None
|
|
|
|
|
|
def create_scenario_template(
    *,
    name: str,
    description: str | None = None,
    test_template_ids: list[uuid.UUID] | None = None,
) -> ScenarioTemplateView:
    """Insert a scenario template together with its ordered test links.

    ``name`` is stripped and must be non-empty (ValueError otherwise). Each
    id in ``test_template_ids`` is validated first and then stored with its
    list index as ``position``. Returns the view of the new row.
    """
    cleaned_name = (name or "").strip()
    if not cleaned_name:
        raise ValueError("name is required")

    ordered_ids = list(test_template_ids or [])
    with session_scope() as s:
        _validate_test_ids(s, ordered_ids)

        scenario = ScenarioTemplate(
            name=cleaned_name,
            description=_opt_str(description),
        )
        s.add(scenario)
        # Flush before building the link rows, which read scenario.id.
        s.flush()

        s.add_all(
            ScenarioTemplateTest(
                scenario_template_id=scenario.id,
                test_template_id=template_id,
                position=index,
            )
            for index, template_id in enumerate(ordered_ids)
        )
        s.flush()

        s.refresh(scenario)
        return _to_view(s, scenario)
|
|
|
|
|
|
def update_scenario_template(
    scenario_id: uuid.UUID,
    *,
    name: str | None = None,
    description: Any = _UNSET,
) -> ScenarioTemplateView:
    """Patch the name and/or description of a live scenario template.

    ``name=None`` leaves the name untouched; a name that is blank after
    stripping raises ValueError. ``description`` is only written when the
    caller passes it, so an explicit None (or blank string) clears it.
    Raises ScenarioTemplateNotFound for a missing or soft-deleted row.
    """
    with session_scope() as s:
        target = s.get(ScenarioTemplate, scenario_id)
        if target is None or target.deleted_at is not None:
            raise ScenarioTemplateNotFound()

        if name is not None:
            stripped = name.strip()
            if not stripped:
                raise ValueError("name cannot be empty")
            target.name = stripped

        # Identity check against the sentinel: None is a legitimate value here.
        if description is not _UNSET:
            target.description = _opt_str(description)

        s.flush()
        s.refresh(target)
        return _to_view(s, target)
|
|
|
|
|
|
def set_scenario_tests(
    scenario_id: uuid.UUID,
    test_template_ids: list[uuid.UUID],
) -> ScenarioTemplateView:
    """Swap out the scenario's whole ordered test list for ``test_template_ids``.

    Each id is stored with its list index as ``position``. A transaction-scoped
    Postgres advisory lock keyed on the scenario is taken first, so concurrent
    calls for the same scenario run one after another instead of racing on
    the wipe-then-insert sequence (which, per the original author, deadlocks
    on the UNIQUE(position) constraint under READ COMMITTED).

    Raises ScenarioTemplateNotFound for a missing or soft-deleted scenario and
    UnknownTestTemplate (via validation) for unusable test template ids.
    """
    with session_scope() as s:
        # Derive the lock key from the scenario UUID so unrelated scenarios do
        # not block each other. The single-argument bigint overload is used,
        # hence a signed 64-bit key. A stable blake2b digest is used rather
        # than built-in hash(), whose str/bytes results differ per process
        # under hash randomisation (PYTHONHASHSEED); every worker process
        # must derive the same key for the lock to serialise across them.
        key_bytes = hashlib.blake2b(scenario_id.bytes, digest_size=8).digest()
        advisory_key = int.from_bytes(key_bytes, "big", signed=True)
        lock_stmt = text("SELECT pg_advisory_xact_lock(CAST(:key AS bigint))")
        s.execute(lock_stmt, {"key": advisory_key})

        scenario = s.get(ScenarioTemplate, scenario_id)
        if scenario is None or scenario.deleted_at is not None:
            raise ScenarioTemplateNotFound()
        _validate_test_ids(s, test_template_ids)

        # Full replace: delete every existing link and flush, then insert the
        # new links. The UNIQUE(position) constraint rules out swapping
        # positions with in-place UPDATEs.
        for existing_link in list(scenario.tests):
            s.delete(existing_link)
        s.flush()

        s.add_all(
            ScenarioTemplateTest(
                scenario_template_id=scenario.id,
                test_template_id=template_id,
                position=index,
            )
            for index, template_id in enumerate(test_template_ids)
        )
        s.flush()

        s.refresh(scenario)
        return _to_view(s, scenario)
|
|
|
|
|
|
def soft_delete_scenario_template(scenario_id: uuid.UUID) -> None:
    """Mark a live scenario template as deleted by stamping ``deleted_at``.

    The timestamp is the current UTC time. Raises ScenarioTemplateNotFound
    when the row does not exist or is already soft-deleted.
    """
    with session_scope() as s:
        target = s.get(ScenarioTemplate, scenario_id)
        already_gone = target is None or target.deleted_at is not None
        if already_gone:
            raise ScenarioTemplateNotFound()
        target.deleted_at = datetime.now(tz=timezone.utc)
|