- Service `app/services/test_templates.py`: CRUD with MITRE tag resolution (kind, external_id) → polymorphic join, filters by tactic/technique/subtechnique/opsec/tag, `_UNSET` sentinel for partial-update semantics. - Service `app/services/scenario_templates.py`: ordered test list, reorder via full-replace (atomic w.r.t. UNIQUE(position) constraint), soft-delete. - REST endpoints on /api/v1/test-templates and /scenario-templates with pydantic schemas + perm gating (test_template.* and scenario_template.*). - /diag/reset truncates the 4 new tables before MITRE (FK ordering). - 19 pytest tests covering CRUD, MITRE tag merge, soft-delete chaining, perm enforcement, and reorder atomicity. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
241 lines
7.6 KiB
Python
241 lines
7.6 KiB
Python
"""CRUD service for `scenario_templates` + their ordered test list.
|
|
|
|
Re-ordering is implemented as **full delete + re-insert** of the
|
|
`scenario_template_tests` rows. The UNIQUE (scenario_template_id, position)
|
|
constraint makes any naive position-swap fail mid-transaction; wiping the set
|
|
then re-inserting at positions 0..N-1 keeps the operation atomic and obvious.
|
|
|
|
The same test_template may legitimately appear multiple times in a scenario
|
|
(chained operations), so we key on `(scenario_id, position)`, not
|
|
`(scenario_id, test_template_id)`.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import uuid
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timezone
|
|
from typing import Any
|
|
|
|
from sqlalchemy import func, or_, select
|
|
from sqlalchemy.orm import Session, selectinload
|
|
|
|
# Sentinel default meaning "argument not supplied". Partial-update functions
# (e.g. `update_scenario_template`'s `description`) compare against it with
# `is`, so an explicit ``None`` can still mean "clear this field".
# NOTE(review): defined between import groups, which trips E402 on the
# imports below; consider moving it after the last import.
_UNSET: Any = object()
|
|
|
|
from app.db.session import session_scope
|
|
from app.models.template import (
|
|
ScenarioTemplate,
|
|
ScenarioTemplateTest,
|
|
TestTemplate,
|
|
)
|
|
|
|
|
|
class ScenarioTemplateNotFound(Exception):
    """No usable scenario template exists for the requested id.

    Raised when the id is unknown, or when the row is soft-deleted and the
    operation does not accept deleted rows.
    """
|
|
|
|
|
|
class UnknownTestTemplate(Exception):
    """A scenario tried to reference a test_template id that does not exist
    or has been soft-deleted; the message lists the offending ids."""
|
|
|
|
|
|
@dataclass(frozen=True)
class ScenarioTestView:
    """Immutable projection of one slot in a scenario's ordered test list."""

    # 0-based index of this slot within the scenario.
    position: int
    test_template_id: uuid.UUID
    # Name of the referenced template ("<missing>" if its row was not found).
    test_template_name: str
    # True when the referenced template is soft-deleted or missing.
    test_template_deleted: bool
|
|
|
|
|
|
@dataclass(frozen=True)
class ScenarioTemplateView:
    """Immutable, session-independent snapshot of a scenario template."""

    id: uuid.UUID
    name: str
    description: str | None
    # Ordered test slots.
    tests: list[ScenarioTestView]
    # Number of entries in `tests`.
    tests_count: int
    # Set when the template has been soft-deleted.
    deleted_at: datetime | None
    created_at: datetime
    updated_at: datetime
|
|
|
|
|
|
def _to_view(s: Session, sc: ScenarioTemplate) -> ScenarioTemplateView:
    """Build an immutable ScenarioTemplateView from an ORM scenario.

    Must be called while ``s`` is open: it reads ``sc.tests`` and issues one
    batched query for the referenced test templates.

    Links are emitted ordered by ``position``, so the view always reflects
    the scenario's run order regardless of how the ``tests`` relationship
    itself is ordered. A link whose template row cannot be found is shown
    as ``"<missing>"`` and flagged deleted instead of raising.
    """
    # Positions are unique per scenario, so this sort is total.
    links = sorted(sc.tests, key=lambda link: link.position)

    # A template may fill several positions (chained operations): fetch each
    # one once, and only the columns the view actually needs.
    unique_ids = list(dict.fromkeys(link.test_template_id for link in links))
    meta_by_id: dict[uuid.UUID, tuple[str, bool]] = {}
    if unique_ids:
        rows = s.execute(
            select(TestTemplate.id, TestTemplate.name, TestTemplate.deleted_at).where(
                TestTemplate.id.in_(unique_ids)
            )
        ).all()
        meta_by_id = {
            tid: (tname, deleted_at is not None) for tid, tname, deleted_at in rows
        }

    tests: list[ScenarioTestView] = []
    for link in links:
        name, deleted = meta_by_id.get(link.test_template_id, ("<missing>", True))
        tests.append(
            ScenarioTestView(
                position=link.position,
                test_template_id=link.test_template_id,
                test_template_name=name,
                test_template_deleted=deleted,
            )
        )

    return ScenarioTemplateView(
        id=sc.id,
        name=sc.name,
        description=sc.description,
        tests=tests,
        tests_count=len(tests),
        deleted_at=sc.deleted_at,
        created_at=sc.created_at,
        updated_at=sc.updated_at,
    )
|
|
|
|
|
|
def _base_query():
    """Base SELECT over scenario templates with their test links eager-loaded.

    ``selectinload`` loads every returned row's ``tests`` collection in one
    follow-up query instead of one lazy load per scenario.
    """
    eager_tests = selectinload(ScenarioTemplate.tests)
    return select(ScenarioTemplate).options(eager_tests)
|
|
|
|
|
|
def _escape_like(term: str, escape: str = "/") -> str:
    """Escape LIKE metacharacters in *term* so it matches literally.

    The escape character itself is doubled first, then ``%`` and ``_`` are
    prefixed with it. Use together with ``.like(..., escape=escape)``.
    """
    return (
        term.replace(escape, escape + escape)
        .replace("%", escape + "%")
        .replace("_", escape + "_")
    )


def list_scenario_templates(
    *,
    q: str | None = None,
    include_deleted: bool = False,
    limit: int = 100,
    offset: int = 0,
) -> tuple[list[ScenarioTemplateView], int]:
    """Return one page of scenario templates plus the total match count.

    Args:
        q: Optional case-insensitive substring matched against ``name`` and
            ``description``. Surrounding whitespace is ignored, and ``%`` /
            ``_`` in the input are matched literally, not as LIKE wildcards.
        include_deleted: Also return soft-deleted templates when True.
        limit: Page size, clamped to ``1..500``.
        offset: Rows to skip; negative values are treated as 0.

    Returns:
        ``(views, total)``. ``total`` counts every match regardless of
        ``limit``/``offset``. Rows are ordered by name, then id, so paging is
        stable even when several templates share a name.
    """
    with session_scope() as s:
        # id breaks ties between equal names -> deterministic pagination.
        stmt = _base_query().order_by(
            ScenarioTemplate.name.asc(), ScenarioTemplate.id.asc()
        )
        count_stmt = select(func.count()).select_from(ScenarioTemplate)
        if not include_deleted:
            live = ScenarioTemplate.deleted_at.is_(None)
            stmt = stmt.where(live)
            count_stmt = count_stmt.where(live)
        term = (q or "").strip()
        if term:
            pattern = f"%{_escape_like(term.lower())}%"
            cond = or_(
                func.lower(ScenarioTemplate.name).like(pattern, escape="/"),
                # A NULL description yields NULL here; OR still matches on name.
                func.lower(ScenarioTemplate.description).like(pattern, escape="/"),
            )
            stmt = stmt.where(cond)
            count_stmt = count_stmt.where(cond)
        total = s.scalar(count_stmt) or 0
        page_size = max(1, min(limit, 500))
        rows = s.scalars(stmt.limit(page_size).offset(max(0, offset))).all()
        return [_to_view(s, sc) for sc in rows], int(total)
|
|
|
|
|
|
def get_scenario_template(scenario_id: uuid.UUID, *, include_deleted: bool = False) -> ScenarioTemplateView:
    """Load a single scenario template as a view.

    Raises:
        ScenarioTemplateNotFound: the id is unknown, or the template is
            soft-deleted and ``include_deleted`` is False.
    """
    with session_scope() as s:
        found = s.get(ScenarioTemplate, scenario_id)
        if found is None or (found.deleted_at is not None and not include_deleted):
            raise ScenarioTemplateNotFound()
        return _to_view(s, found)
|
|
|
|
|
|
def _validate_test_ids(s: Session, ids: list[uuid.UUID]) -> None:
|
|
"""Reject unknown or soft-deleted test_template ids before persisting."""
|
|
if not ids:
|
|
return
|
|
found = s.execute(
|
|
select(TestTemplate.id, TestTemplate.deleted_at).where(TestTemplate.id.in_(ids))
|
|
).all()
|
|
known = {row.id for row in found}
|
|
deleted = {row.id for row in found if row.deleted_at is not None}
|
|
missing = set(ids) - known
|
|
if missing:
|
|
raise UnknownTestTemplate(f"unknown test_template ids: {sorted(str(m) for m in missing)}")
|
|
if deleted:
|
|
raise UnknownTestTemplate(
|
|
f"cannot reference soft-deleted test_template ids: {sorted(str(d) for d in deleted)}"
|
|
)
|
|
|
|
|
|
def _opt_str(value: str | None) -> str | None:
|
|
if value is None:
|
|
return None
|
|
s = value.strip()
|
|
return s or None
|
|
|
|
|
|
def create_scenario_template(
    *,
    name: str,
    description: str | None = None,
    test_template_ids: list[uuid.UUID] | None = None,
) -> ScenarioTemplateView:
    """Create a scenario template, optionally seeded with an ordered test list.

    Each id's index in ``test_template_ids`` becomes its stored ``position``
    (0-based). The same id may appear more than once. Blank descriptions are
    stored as NULL.

    Raises:
        ValueError: ``name`` is missing or blank.
        UnknownTestTemplate: an id is unknown or soft-deleted.
    """
    cleaned_name = (name or "").strip()
    if not cleaned_name:
        raise ValueError("name is required")
    ordered_ids = list(test_template_ids or [])

    with session_scope() as s:
        _validate_test_ids(s, ordered_ids)
        scenario = ScenarioTemplate(name=cleaned_name, description=_opt_str(description))
        s.add(scenario)
        # Flush so scenario.id is assigned before the link rows reference it.
        s.flush()
        s.add_all(
            ScenarioTemplateTest(
                scenario_template_id=scenario.id,
                test_template_id=tid,
                position=idx,
            )
            for idx, tid in enumerate(ordered_ids)
        )
        s.flush()
        s.refresh(scenario)
        return _to_view(s, scenario)
|
|
|
|
|
|
def update_scenario_template(
    scenario_id: uuid.UUID,
    *,
    name: str | None = None,
    description: Any = _UNSET,
) -> ScenarioTemplateView:
    """Partially update a live scenario template.

    ``name=None`` keeps the current name, and a blank name is rejected.
    Omitting ``description`` keeps it unchanged. Passing ``None`` or a blank
    string clears it.

    Raises:
        ScenarioTemplateNotFound: the id is unknown or soft-deleted.
        ValueError: ``name`` was given but is blank.
    """
    with session_scope() as s:
        scenario = s.get(ScenarioTemplate, scenario_id)
        if scenario is None or scenario.deleted_at is not None:
            raise ScenarioTemplateNotFound()

        if name is not None:
            new_name = name.strip()
            if not new_name:
                raise ValueError("name cannot be empty")
            scenario.name = new_name

        # Sentinel check: only an explicitly passed description is applied.
        if description is not _UNSET:
            scenario.description = _opt_str(description)

        s.flush()
        s.refresh(scenario)
        return _to_view(s, scenario)
|
|
|
|
|
|
def set_scenario_tests(
    scenario_id: uuid.UUID,
    test_template_ids: list[uuid.UUID],
) -> ScenarioTemplateView:
    """Replace the entire ordered test list. `position` becomes the index.

    The ids are materialized into a list once, up front, so exactly the same
    sequence is validated and then inserted. Otherwise a caller passing a
    one-shot iterable (e.g. a generator) could have it exhausted by
    validation, leaving nothing to insert and silently emptying the scenario.

    Raises:
        ScenarioTemplateNotFound: the scenario id is unknown or soft-deleted.
        UnknownTestTemplate: an id is unknown or soft-deleted.
    """
    ids = list(test_template_ids)
    with session_scope() as s:
        sc = s.get(ScenarioTemplate, scenario_id)
        if sc is None or sc.deleted_at is not None:
            raise ScenarioTemplateNotFound()
        _validate_test_ids(s, ids)
        # Wipe then re-insert. UNIQUE(scenario_template_id, position) forbids a
        # naive in-place position swap. Flushing the DELETEs before the INSERTs
        # lets positions 0..N-1 be reused inside the same transaction.
        for link in list(sc.tests):
            s.delete(link)
        s.flush()
        for position, tid in enumerate(ids):
            s.add(
                ScenarioTemplateTest(
                    scenario_template_id=sc.id,
                    test_template_id=tid,
                    position=position,
                )
            )
        s.flush()
        s.refresh(sc)
        return _to_view(s, sc)
|
|
|
|
|
|
def soft_delete_scenario_template(scenario_id: uuid.UUID) -> None:
    """Mark a live scenario template as deleted by stamping ``deleted_at`` (UTC).

    Only ``deleted_at`` is set. The row and its test links are not removed.

    Raises:
        ScenarioTemplateNotFound: the id is unknown or already soft-deleted.
    """
    with session_scope() as s:
        scenario = s.get(ScenarioTemplate, scenario_id)
        already_gone = scenario is None or scenario.deleted_at is not None
        if already_gone:
            raise ScenarioTemplateNotFound()
        scenario.deleted_at = datetime.now(tz=timezone.utc)
|