"""CRUD service for `scenario_templates` + their ordered test list. Re-ordering is implemented as **full delete + re-insert** of the `scenario_template_tests` rows. The UNIQUE (scenario_template_id, position) constraint makes any naive position-swap fail mid-transaction; wiping the set then re-inserting at positions 0..N-1 keeps the operation atomic and obvious. The same test_template may legitimately appear multiple times in a scenario (chained operations), so we key on `(scenario_id, position)`, not `(scenario_id, test_template_id)`. """ from __future__ import annotations import hashlib import uuid from dataclasses import dataclass from datetime import datetime, timezone from typing import Any from sqlalchemy import func, or_, select, text from sqlalchemy.orm import Session, selectinload _UNSET: Any = object() from app.db.session import session_scope from app.models.template import ( ScenarioTemplate, ScenarioTemplateTest, TestTemplate, ) class ScenarioTemplateNotFound(Exception): pass class UnknownTestTemplate(Exception): """Raised when a scenario references a non-existent or soft-deleted test.""" @dataclass(frozen=True) class ScenarioTestView: position: int test_template_id: uuid.UUID test_template_name: str test_template_deleted: bool @dataclass(frozen=True) class ScenarioTemplateView: id: uuid.UUID name: str description: str | None tests: list[ScenarioTestView] tests_count: int deleted_at: datetime | None created_at: datetime updated_at: datetime def _to_view(s: Session, sc: ScenarioTemplate) -> ScenarioTemplateView: test_ids = [link.test_template_id for link in sc.tests] name_by_id: dict[uuid.UUID, tuple[str, bool]] = {} if test_ids: rows = s.scalars(select(TestTemplate).where(TestTemplate.id.in_(test_ids))).all() for row in rows: name_by_id[row.id] = (row.name, row.deleted_at is not None) tests = [ ScenarioTestView( position=link.position, test_template_id=link.test_template_id, test_template_name=name_by_id.get(link.test_template_id, ("", True))[0], test_template_deleted=name_by_id.get(link.test_template_id, ("", True))[1], ) for link in sc.tests ] return ScenarioTemplateView( id=sc.id, name=sc.name, description=sc.description, tests=tests, tests_count=len(tests), deleted_at=sc.deleted_at, created_at=sc.created_at, updated_at=sc.updated_at, ) def _base_query(): return select(ScenarioTemplate).options(selectinload(ScenarioTemplate.tests)) def list_scenario_templates( *, q: str | None = None, include_deleted: bool = False, limit: int = 100, offset: int = 0, ) -> tuple[list[ScenarioTemplateView], int]: with session_scope() as s: stmt = _base_query().order_by(ScenarioTemplate.name.asc()) count_stmt = select(func.count()).select_from(ScenarioTemplate) if not include_deleted: stmt = stmt.where(ScenarioTemplate.deleted_at.is_(None)) count_stmt = count_stmt.where(ScenarioTemplate.deleted_at.is_(None)) if q: like = f"%{q.lower()}%" cond = or_( func.lower(ScenarioTemplate.name).like(like), func.lower(ScenarioTemplate.description).like(like), ) stmt = stmt.where(cond) count_stmt = count_stmt.where(cond) total = s.scalar(count_stmt) or 0 rows = s.scalars(stmt.limit(max(1, min(limit, 500))).offset(max(0, offset))).all() return [_to_view(s, sc) for sc in rows], int(total) def get_scenario_template(scenario_id: uuid.UUID, *, include_deleted: bool = False) -> ScenarioTemplateView: with session_scope() as s: sc = s.get(ScenarioTemplate, scenario_id) if sc is None: raise ScenarioTemplateNotFound() if sc.deleted_at is not None and not include_deleted: raise ScenarioTemplateNotFound() return _to_view(s, sc) def _validate_test_ids(s: Session, ids: list[uuid.UUID]) -> None: """Reject unknown or soft-deleted test_template ids before persisting.""" if not ids: return found = s.execute( select(TestTemplate.id, TestTemplate.deleted_at).where(TestTemplate.id.in_(ids)) ).all() known = {row.id for row in found} deleted = {row.id for row in found if row.deleted_at is not None} missing = set(ids) - known if missing: raise UnknownTestTemplate(f"unknown test_template ids: {sorted(str(m) for m in missing)}") if deleted: raise UnknownTestTemplate( f"cannot reference soft-deleted test_template ids: {sorted(str(d) for d in deleted)}" ) def _opt_str(value: str | None) -> str | None: if value is None: return None s = value.strip() return s or None def create_scenario_template( *, name: str, description: str | None = None, test_template_ids: list[uuid.UUID] | None = None, ) -> ScenarioTemplateView: name_norm = (name or "").strip() if not name_norm: raise ValueError("name is required") ids = list(test_template_ids or []) with session_scope() as s: _validate_test_ids(s, ids) sc = ScenarioTemplate( name=name_norm, description=_opt_str(description), ) s.add(sc) s.flush() for position, tid in enumerate(ids): s.add( ScenarioTemplateTest( scenario_template_id=sc.id, test_template_id=tid, position=position, ) ) s.flush() s.refresh(sc) return _to_view(s, sc) def update_scenario_template( scenario_id: uuid.UUID, *, name: str | None = None, description: Any = _UNSET, ) -> ScenarioTemplateView: with session_scope() as s: sc = s.get(ScenarioTemplate, scenario_id) if sc is None or sc.deleted_at is not None: raise ScenarioTemplateNotFound() if name is not None: n = name.strip() if not n: raise ValueError("name cannot be empty") sc.name = n if description is not _UNSET: sc.description = _opt_str(description) s.flush() s.refresh(sc) return _to_view(s, sc) def set_scenario_tests( scenario_id: uuid.UUID, test_template_ids: list[uuid.UUID], ) -> ScenarioTemplateView: """Replace the entire ordered test list. `position` becomes the index. Acquires a per-scenario advisory lock to serialise concurrent reorders. Without it, two parallel `PUT /scenario-templates/{id}/tests` calls would race on the wipe-then-insert sequence and deadlock on the UNIQUE(position) constraint under READ COMMITTED. Mirrors the M4 pattern on /mitre/sync. """ with session_scope() as s: # Lock keyed on the scenario UUID — different scenarios don't block # each other. Single bigint form so we don't have to juggle int32 # signed ranges. blake2b is used instead of Python's built-in hash() # because the latter is randomised per-process (PYTHONHASHSEED), so # two gunicorn workers would compute different keys for the same # scenario and the lock wouldn't serialise across them. digest = hashlib.blake2b(scenario_id.bytes, digest_size=8).digest() lock_key = int.from_bytes(digest, "big", signed=True) s.execute( text("SELECT pg_advisory_xact_lock(CAST(:key AS bigint))"), {"key": lock_key}, ) sc = s.get(ScenarioTemplate, scenario_id) if sc is None or sc.deleted_at is not None: raise ScenarioTemplateNotFound() _validate_test_ids(s, test_template_ids) # Wipe then re-insert. The UNIQUE(position) constraint forbids a # naive UPDATE-swap; full-replace keeps the op atomic + readable. for link in list(sc.tests): s.delete(link) s.flush() for position, tid in enumerate(test_template_ids): s.add( ScenarioTemplateTest( scenario_template_id=sc.id, test_template_id=tid, position=position, ) ) s.flush() s.refresh(sc) return _to_view(s, sc) def soft_delete_scenario_template(scenario_id: uuid.UUID) -> None: with session_scope() as s: sc = s.get(ScenarioTemplate, scenario_id) if sc is None or sc.deleted_at is not None: raise ScenarioTemplateNotFound() sc.deleted_at = datetime.now(tz=timezone.utc)