feat(m5): test_template + scenario_template CRUD with MITRE tags and ordered tests
- Service `app/services/test_templates.py`: CRUD with MITRE tag resolution (kind, external_id) → polymorphic join, filters by tactic/technique/ subtechnique/opsec/tag, `_UNSET` sentinel for partial-update semantics. - Service `app/services/scenario_templates.py`: ordered test list, reorder via full-replace (atomic w.r.t. UNIQUE(position) constraint), soft-delete. - REST endpoints on /api/v1/test-templates and /scenario-templates with pydantic schemas + perm gating (test_template.* and scenario_template.*). - /diag/reset truncates the 4 new tables before MITRE (FK ordering). - 19 pytest covering CRUD, MITRE tag merge, soft-delete chaining, perm enforcement, and reorder atomicity. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
240
backend/app/services/scenario_templates.py
Normal file
240
backend/app/services/scenario_templates.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""CRUD service for `scenario_templates` + their ordered test list.
|
||||
|
||||
Re-ordering is implemented as **full delete + re-insert** of the
|
||||
`scenario_template_tests` rows. The UNIQUE (scenario_template_id, position)
|
||||
constraint makes any naive position-swap fail mid-transaction; wiping the set
|
||||
then re-inserting at positions 0..N-1 keeps the operation atomic and obvious.
|
||||
|
||||
The same test_template may legitimately appear multiple times in a scenario
|
||||
(chained operations), so we key on `(scenario_id, position)`, not
|
||||
`(scenario_id, test_template_id)`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
_UNSET: Any = object()
|
||||
|
||||
from app.db.session import session_scope
|
||||
from app.models.template import (
|
||||
ScenarioTemplate,
|
||||
ScenarioTemplateTest,
|
||||
TestTemplate,
|
||||
)
|
||||
|
||||
|
||||
class ScenarioTemplateNotFound(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class UnknownTestTemplate(Exception):
|
||||
"""Raised when a scenario references a non-existent or soft-deleted test."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScenarioTestView:
|
||||
position: int
|
||||
test_template_id: uuid.UUID
|
||||
test_template_name: str
|
||||
test_template_deleted: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScenarioTemplateView:
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
description: str | None
|
||||
tests: list[ScenarioTestView]
|
||||
tests_count: int
|
||||
deleted_at: datetime | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
def _to_view(s: Session, sc: ScenarioTemplate) -> ScenarioTemplateView:
|
||||
test_ids = [link.test_template_id for link in sc.tests]
|
||||
name_by_id: dict[uuid.UUID, tuple[str, bool]] = {}
|
||||
if test_ids:
|
||||
rows = s.scalars(select(TestTemplate).where(TestTemplate.id.in_(test_ids))).all()
|
||||
for row in rows:
|
||||
name_by_id[row.id] = (row.name, row.deleted_at is not None)
|
||||
tests = [
|
||||
ScenarioTestView(
|
||||
position=link.position,
|
||||
test_template_id=link.test_template_id,
|
||||
test_template_name=name_by_id.get(link.test_template_id, ("<missing>", True))[0],
|
||||
test_template_deleted=name_by_id.get(link.test_template_id, ("<missing>", True))[1],
|
||||
)
|
||||
for link in sc.tests
|
||||
]
|
||||
return ScenarioTemplateView(
|
||||
id=sc.id,
|
||||
name=sc.name,
|
||||
description=sc.description,
|
||||
tests=tests,
|
||||
tests_count=len(tests),
|
||||
deleted_at=sc.deleted_at,
|
||||
created_at=sc.created_at,
|
||||
updated_at=sc.updated_at,
|
||||
)
|
||||
|
||||
|
||||
def _base_query():
|
||||
return select(ScenarioTemplate).options(selectinload(ScenarioTemplate.tests))
|
||||
|
||||
|
||||
def list_scenario_templates(
|
||||
*,
|
||||
q: str | None = None,
|
||||
include_deleted: bool = False,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[ScenarioTemplateView], int]:
|
||||
with session_scope() as s:
|
||||
stmt = _base_query().order_by(ScenarioTemplate.name.asc())
|
||||
count_stmt = select(func.count()).select_from(ScenarioTemplate)
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(ScenarioTemplate.deleted_at.is_(None))
|
||||
count_stmt = count_stmt.where(ScenarioTemplate.deleted_at.is_(None))
|
||||
if q:
|
||||
like = f"%{q.lower()}%"
|
||||
cond = or_(
|
||||
func.lower(ScenarioTemplate.name).like(like),
|
||||
func.lower(ScenarioTemplate.description).like(like),
|
||||
)
|
||||
stmt = stmt.where(cond)
|
||||
count_stmt = count_stmt.where(cond)
|
||||
total = s.scalar(count_stmt) or 0
|
||||
rows = s.scalars(stmt.limit(max(1, min(limit, 500))).offset(max(0, offset))).all()
|
||||
return [_to_view(s, sc) for sc in rows], int(total)
|
||||
|
||||
|
||||
def get_scenario_template(scenario_id: uuid.UUID, *, include_deleted: bool = False) -> ScenarioTemplateView:
|
||||
with session_scope() as s:
|
||||
sc = s.get(ScenarioTemplate, scenario_id)
|
||||
if sc is None:
|
||||
raise ScenarioTemplateNotFound()
|
||||
if sc.deleted_at is not None and not include_deleted:
|
||||
raise ScenarioTemplateNotFound()
|
||||
return _to_view(s, sc)
|
||||
|
||||
|
||||
def _validate_test_ids(s: Session, ids: list[uuid.UUID]) -> None:
|
||||
"""Reject unknown or soft-deleted test_template ids before persisting."""
|
||||
if not ids:
|
||||
return
|
||||
found = s.execute(
|
||||
select(TestTemplate.id, TestTemplate.deleted_at).where(TestTemplate.id.in_(ids))
|
||||
).all()
|
||||
known = {row.id for row in found}
|
||||
deleted = {row.id for row in found if row.deleted_at is not None}
|
||||
missing = set(ids) - known
|
||||
if missing:
|
||||
raise UnknownTestTemplate(f"unknown test_template ids: {sorted(str(m) for m in missing)}")
|
||||
if deleted:
|
||||
raise UnknownTestTemplate(
|
||||
f"cannot reference soft-deleted test_template ids: {sorted(str(d) for d in deleted)}"
|
||||
)
|
||||
|
||||
|
||||
def _opt_str(value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
s = value.strip()
|
||||
return s or None
|
||||
|
||||
|
||||
def create_scenario_template(
|
||||
*,
|
||||
name: str,
|
||||
description: str | None = None,
|
||||
test_template_ids: list[uuid.UUID] | None = None,
|
||||
) -> ScenarioTemplateView:
|
||||
name_norm = (name or "").strip()
|
||||
if not name_norm:
|
||||
raise ValueError("name is required")
|
||||
ids = list(test_template_ids or [])
|
||||
with session_scope() as s:
|
||||
_validate_test_ids(s, ids)
|
||||
sc = ScenarioTemplate(
|
||||
name=name_norm,
|
||||
description=_opt_str(description),
|
||||
)
|
||||
s.add(sc)
|
||||
s.flush()
|
||||
for position, tid in enumerate(ids):
|
||||
s.add(
|
||||
ScenarioTemplateTest(
|
||||
scenario_template_id=sc.id,
|
||||
test_template_id=tid,
|
||||
position=position,
|
||||
)
|
||||
)
|
||||
s.flush()
|
||||
s.refresh(sc)
|
||||
return _to_view(s, sc)
|
||||
|
||||
|
||||
def update_scenario_template(
|
||||
scenario_id: uuid.UUID,
|
||||
*,
|
||||
name: str | None = None,
|
||||
description: Any = _UNSET,
|
||||
) -> ScenarioTemplateView:
|
||||
with session_scope() as s:
|
||||
sc = s.get(ScenarioTemplate, scenario_id)
|
||||
if sc is None or sc.deleted_at is not None:
|
||||
raise ScenarioTemplateNotFound()
|
||||
if name is not None:
|
||||
n = name.strip()
|
||||
if not n:
|
||||
raise ValueError("name cannot be empty")
|
||||
sc.name = n
|
||||
if description is not _UNSET:
|
||||
sc.description = _opt_str(description)
|
||||
s.flush()
|
||||
s.refresh(sc)
|
||||
return _to_view(s, sc)
|
||||
|
||||
|
||||
def set_scenario_tests(
|
||||
scenario_id: uuid.UUID,
|
||||
test_template_ids: list[uuid.UUID],
|
||||
) -> ScenarioTemplateView:
|
||||
"""Replace the entire ordered test list. `position` becomes the index."""
|
||||
with session_scope() as s:
|
||||
sc = s.get(ScenarioTemplate, scenario_id)
|
||||
if sc is None or sc.deleted_at is not None:
|
||||
raise ScenarioTemplateNotFound()
|
||||
_validate_test_ids(s, test_template_ids)
|
||||
# Wipe then re-insert. The UNIQUE(position) constraint forbids a
|
||||
# naive UPDATE-swap; full-replace keeps the op atomic + readable.
|
||||
for link in list(sc.tests):
|
||||
s.delete(link)
|
||||
s.flush()
|
||||
for position, tid in enumerate(test_template_ids):
|
||||
s.add(
|
||||
ScenarioTemplateTest(
|
||||
scenario_template_id=sc.id,
|
||||
test_template_id=tid,
|
||||
position=position,
|
||||
)
|
||||
)
|
||||
s.flush()
|
||||
s.refresh(sc)
|
||||
return _to_view(s, sc)
|
||||
|
||||
|
||||
def soft_delete_scenario_template(scenario_id: uuid.UUID) -> None:
|
||||
with session_scope() as s:
|
||||
sc = s.get(ScenarioTemplate, scenario_id)
|
||||
if sc is None or sc.deleted_at is not None:
|
||||
raise ScenarioTemplateNotFound()
|
||||
sc.deleted_at = datetime.now(tz=timezone.utc)
|
||||
395
backend/app/services/test_templates.py
Normal file
395
backend/app/services/test_templates.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""CRUD service for `test_templates` + their MITRE tags.
|
||||
|
||||
The MITRE tag set is **fully replaced** on every update — partial mutation of
|
||||
the join rows would force the API client to track tag UUIDs they never created.
|
||||
The polymorphic join (one of `tactic_id` / `technique_id` / `subtechnique_id`
|
||||
populated) is owned here: callers pass `(kind, external_id)` tuples and we
|
||||
resolve them to the matching MITRE row.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Iterable
|
||||
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
_UNSET: Any = object()
|
||||
|
||||
from app.db.session import session_scope
|
||||
from app.db.types import MITRE_KINDS, OPSEC_LEVELS
|
||||
from app.models.mitre import MitreSubtechnique, MitreTactic, MitreTechnique
|
||||
from app.models.template import TestTemplate, TestTemplateMitreTag
|
||||
|
||||
|
||||
class TestTemplateNotFound(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class UnknownMitreTag(Exception):
|
||||
"""Raised when an (kind, external_id) tuple doesn't resolve to a known MITRE row."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MitreTagRef:
|
||||
"""Inbound MITRE tag reference. `external_id` is the ATT&CK identifier
|
||||
(TA…/T…/T….…) — we resolve it server-side, the client never sees UUIDs.
|
||||
"""
|
||||
|
||||
kind: str # "tactic" | "technique" | "subtechnique"
|
||||
external_id: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MitreTagView:
|
||||
kind: str
|
||||
external_id: str
|
||||
name: str
|
||||
url: str | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TestTemplateView:
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
description: str | None
|
||||
objective: str | None
|
||||
procedure_md: str | None
|
||||
prerequisites_md: str | None
|
||||
expected_result_red_md: str | None
|
||||
expected_detection_blue_md: str | None
|
||||
opsec_level: str
|
||||
tags: list[str]
|
||||
expected_iocs: list[str]
|
||||
mitre_tags: list[MitreTagView]
|
||||
deleted_at: datetime | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
def _validate_opsec(value: str) -> str:
|
||||
if value not in OPSEC_LEVELS:
|
||||
raise ValueError(f"opsec_level must be one of {OPSEC_LEVELS}")
|
||||
return value
|
||||
|
||||
|
||||
def _normalize_string_list(values: Iterable[str] | None) -> list[str]:
|
||||
if not values:
|
||||
return []
|
||||
seen: set[str] = set()
|
||||
out: list[str] = []
|
||||
for raw in values:
|
||||
if not isinstance(raw, str):
|
||||
raise ValueError("list items must be strings")
|
||||
v = raw.strip()
|
||||
if not v or v in seen:
|
||||
continue
|
||||
seen.add(v)
|
||||
out.append(v)
|
||||
return out
|
||||
|
||||
|
||||
def _resolve_mitre_refs(s: Session, refs: list[MitreTagRef]) -> list[TestTemplateMitreTag]:
|
||||
"""Translate `(kind, external_id)` pairs into half-populated join rows.
|
||||
|
||||
Validates that:
|
||||
- `kind` is one of the supported values
|
||||
- each external_id resolves to an existing MITRE row
|
||||
- the combination is unique inside the payload (de-duped silently — same
|
||||
tag twice is a no-op, not an error)
|
||||
"""
|
||||
if not refs:
|
||||
return []
|
||||
# Dedupe input
|
||||
deduped: dict[tuple[str, str], MitreTagRef] = {}
|
||||
for ref in refs:
|
||||
if ref.kind not in MITRE_KINDS:
|
||||
raise ValueError(f"mitre tag kind must be one of {MITRE_KINDS}")
|
||||
if not ref.external_id:
|
||||
raise ValueError("mitre tag external_id is required")
|
||||
deduped[(ref.kind, ref.external_id)] = ref
|
||||
|
||||
tactic_ids = {r.external_id for r in deduped.values() if r.kind == "tactic"}
|
||||
technique_ids = {r.external_id for r in deduped.values() if r.kind == "technique"}
|
||||
subtechnique_ids = {r.external_id for r in deduped.values() if r.kind == "subtechnique"}
|
||||
|
||||
tactic_map = {
|
||||
t.external_id: t.id
|
||||
for t in s.scalars(select(MitreTactic).where(MitreTactic.external_id.in_(tactic_ids))).all()
|
||||
}
|
||||
technique_map = {
|
||||
t.external_id: t.id
|
||||
for t in s.scalars(select(MitreTechnique).where(MitreTechnique.external_id.in_(technique_ids))).all()
|
||||
}
|
||||
subtechnique_map = {
|
||||
sb.external_id: sb.id
|
||||
for sb in s.scalars(
|
||||
select(MitreSubtechnique).where(MitreSubtechnique.external_id.in_(subtechnique_ids))
|
||||
).all()
|
||||
}
|
||||
|
||||
rows: list[TestTemplateMitreTag] = []
|
||||
missing: list[tuple[str, str]] = []
|
||||
for ref in deduped.values():
|
||||
if ref.kind == "tactic":
|
||||
mid = tactic_map.get(ref.external_id)
|
||||
if mid is None:
|
||||
missing.append((ref.kind, ref.external_id))
|
||||
continue
|
||||
rows.append(TestTemplateMitreTag(mitre_kind="tactic", tactic_id=mid))
|
||||
elif ref.kind == "technique":
|
||||
mid = technique_map.get(ref.external_id)
|
||||
if mid is None:
|
||||
missing.append((ref.kind, ref.external_id))
|
||||
continue
|
||||
rows.append(TestTemplateMitreTag(mitre_kind="technique", technique_id=mid))
|
||||
else:
|
||||
mid = subtechnique_map.get(ref.external_id)
|
||||
if mid is None:
|
||||
missing.append((ref.kind, ref.external_id))
|
||||
continue
|
||||
rows.append(TestTemplateMitreTag(mitre_kind="subtechnique", subtechnique_id=mid))
|
||||
if missing:
|
||||
raise UnknownMitreTag(f"unknown MITRE tags: {sorted(missing)}")
|
||||
return rows
|
||||
|
||||
|
||||
def _to_view(s: Session, t: TestTemplate) -> TestTemplateView:
|
||||
tag_views: list[MitreTagView] = []
|
||||
for tag in t.mitre_tags:
|
||||
if tag.mitre_kind == "tactic" and tag.tactic_id is not None:
|
||||
row = s.get(MitreTactic, tag.tactic_id)
|
||||
if row is not None:
|
||||
tag_views.append(MitreTagView(kind="tactic", external_id=row.external_id, name=row.name, url=row.url))
|
||||
elif tag.mitre_kind == "technique" and tag.technique_id is not None:
|
||||
row = s.get(MitreTechnique, tag.technique_id)
|
||||
if row is not None:
|
||||
tag_views.append(MitreTagView(kind="technique", external_id=row.external_id, name=row.name, url=row.url))
|
||||
elif tag.mitre_kind == "subtechnique" and tag.subtechnique_id is not None:
|
||||
row = s.get(MitreSubtechnique, tag.subtechnique_id)
|
||||
if row is not None:
|
||||
tag_views.append(
|
||||
MitreTagView(kind="subtechnique", external_id=row.external_id, name=row.name, url=row.url)
|
||||
)
|
||||
tag_views.sort(key=lambda v: (v.kind, v.external_id))
|
||||
return TestTemplateView(
|
||||
id=t.id,
|
||||
name=t.name,
|
||||
description=t.description,
|
||||
objective=t.objective,
|
||||
procedure_md=t.procedure_md,
|
||||
prerequisites_md=t.prerequisites_md,
|
||||
expected_result_red_md=t.expected_result_red_md,
|
||||
expected_detection_blue_md=t.expected_detection_blue_md,
|
||||
opsec_level=t.opsec_level,
|
||||
tags=list(t.tags or []),
|
||||
expected_iocs=list(t.expected_iocs or []),
|
||||
mitre_tags=tag_views,
|
||||
deleted_at=t.deleted_at,
|
||||
created_at=t.created_at,
|
||||
updated_at=t.updated_at,
|
||||
)
|
||||
|
||||
|
||||
def _base_query():
|
||||
return select(TestTemplate).options(selectinload(TestTemplate.mitre_tags))
|
||||
|
||||
|
||||
def list_test_templates(
|
||||
*,
|
||||
q: str | None = None,
|
||||
tactic: str | None = None, # external_id like "TA0006"
|
||||
technique: str | None = None,
|
||||
subtechnique: str | None = None,
|
||||
opsec_level: str | None = None,
|
||||
tag: str | None = None,
|
||||
include_deleted: bool = False,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[TestTemplateView], int]:
|
||||
with session_scope() as s:
|
||||
stmt = _base_query().order_by(TestTemplate.name.asc())
|
||||
count_stmt = select(func.count()).select_from(TestTemplate)
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(TestTemplate.deleted_at.is_(None))
|
||||
count_stmt = count_stmt.where(TestTemplate.deleted_at.is_(None))
|
||||
if q:
|
||||
like = f"%{q.lower()}%"
|
||||
cond = or_(
|
||||
func.lower(TestTemplate.name).like(like),
|
||||
func.lower(TestTemplate.description).like(like),
|
||||
)
|
||||
stmt = stmt.where(cond)
|
||||
count_stmt = count_stmt.where(cond)
|
||||
if opsec_level:
|
||||
_validate_opsec(opsec_level)
|
||||
stmt = stmt.where(TestTemplate.opsec_level == opsec_level)
|
||||
count_stmt = count_stmt.where(TestTemplate.opsec_level == opsec_level)
|
||||
if tag:
|
||||
stmt = stmt.where(TestTemplate.tags.any(tag))
|
||||
count_stmt = count_stmt.where(TestTemplate.tags.any(tag))
|
||||
|
||||
# MITRE facet: resolve external_id → uuid then filter via join subquery.
|
||||
if tactic or technique or subtechnique:
|
||||
tag_ids: list[uuid.UUID] = []
|
||||
if tactic:
|
||||
tac = s.scalar(select(MitreTactic).where(MitreTactic.external_id == tactic))
|
||||
if tac is None:
|
||||
return [], 0
|
||||
tag_ids.append(tac.id)
|
||||
if technique:
|
||||
tech = s.scalar(select(MitreTechnique).where(MitreTechnique.external_id == technique))
|
||||
if tech is None:
|
||||
return [], 0
|
||||
tag_ids.append(tech.id)
|
||||
if subtechnique:
|
||||
sub = s.scalar(select(MitreSubtechnique).where(MitreSubtechnique.external_id == subtechnique))
|
||||
if sub is None:
|
||||
return [], 0
|
||||
tag_ids.append(sub.id)
|
||||
sub_q = (
|
||||
select(TestTemplateMitreTag.test_template_id)
|
||||
.where(
|
||||
or_(
|
||||
TestTemplateMitreTag.tactic_id.in_(tag_ids),
|
||||
TestTemplateMitreTag.technique_id.in_(tag_ids),
|
||||
TestTemplateMitreTag.subtechnique_id.in_(tag_ids),
|
||||
)
|
||||
)
|
||||
.distinct()
|
||||
)
|
||||
stmt = stmt.where(TestTemplate.id.in_(sub_q))
|
||||
count_stmt = count_stmt.where(TestTemplate.id.in_(sub_q))
|
||||
|
||||
total = s.scalar(count_stmt) or 0
|
||||
rows = s.scalars(stmt.limit(max(1, min(limit, 500))).offset(max(0, offset))).all()
|
||||
return [_to_view(s, t) for t in rows], int(total)
|
||||
|
||||
|
||||
def get_test_template(template_id: uuid.UUID, *, include_deleted: bool = False) -> TestTemplateView:
|
||||
with session_scope() as s:
|
||||
t = s.get(TestTemplate, template_id)
|
||||
if t is None:
|
||||
raise TestTemplateNotFound()
|
||||
if t.deleted_at is not None and not include_deleted:
|
||||
raise TestTemplateNotFound()
|
||||
return _to_view(s, t)
|
||||
|
||||
|
||||
def create_test_template(
|
||||
*,
|
||||
name: str,
|
||||
description: str | None = None,
|
||||
objective: str | None = None,
|
||||
procedure_md: str | None = None,
|
||||
prerequisites_md: str | None = None,
|
||||
expected_result_red_md: str | None = None,
|
||||
expected_detection_blue_md: str | None = None,
|
||||
opsec_level: str = "medium",
|
||||
tags: list[str] | None = None,
|
||||
expected_iocs: list[str] | None = None,
|
||||
mitre_tags: list[MitreTagRef] | None = None,
|
||||
) -> TestTemplateView:
|
||||
name_norm = (name or "").strip()
|
||||
if not name_norm:
|
||||
raise ValueError("name is required")
|
||||
_validate_opsec(opsec_level)
|
||||
norm_tags = _normalize_string_list(tags)
|
||||
norm_iocs = _normalize_string_list(expected_iocs)
|
||||
with session_scope() as s:
|
||||
t = TestTemplate(
|
||||
name=name_norm,
|
||||
description=_opt_str(description),
|
||||
objective=_opt_str(objective),
|
||||
procedure_md=procedure_md or None,
|
||||
prerequisites_md=prerequisites_md or None,
|
||||
expected_result_red_md=expected_result_red_md or None,
|
||||
expected_detection_blue_md=expected_detection_blue_md or None,
|
||||
opsec_level=opsec_level,
|
||||
tags=norm_tags,
|
||||
expected_iocs=norm_iocs,
|
||||
)
|
||||
s.add(t)
|
||||
s.flush()
|
||||
if mitre_tags:
|
||||
rows = _resolve_mitre_refs(s, mitre_tags)
|
||||
for row in rows:
|
||||
row.test_template_id = t.id
|
||||
s.add(row)
|
||||
s.flush()
|
||||
s.refresh(t)
|
||||
return _to_view(s, t)
|
||||
|
||||
|
||||
def _opt_str(value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
s = value.strip()
|
||||
return s or None
|
||||
|
||||
|
||||
def update_test_template(
|
||||
template_id: uuid.UUID,
|
||||
*,
|
||||
name: str | None = None,
|
||||
description: Any = _UNSET,
|
||||
objective: Any = _UNSET,
|
||||
procedure_md: Any = _UNSET,
|
||||
prerequisites_md: Any = _UNSET,
|
||||
expected_result_red_md: Any = _UNSET,
|
||||
expected_detection_blue_md: Any = _UNSET,
|
||||
opsec_level: str | None = None,
|
||||
tags: Any = _UNSET,
|
||||
expected_iocs: Any = _UNSET,
|
||||
mitre_tags: Any = _UNSET,
|
||||
) -> TestTemplateView:
|
||||
with session_scope() as s:
|
||||
t = s.get(TestTemplate, template_id)
|
||||
if t is None or t.deleted_at is not None:
|
||||
raise TestTemplateNotFound()
|
||||
if name is not None:
|
||||
n = name.strip()
|
||||
if not n:
|
||||
raise ValueError("name cannot be empty")
|
||||
t.name = n
|
||||
if description is not _UNSET:
|
||||
t.description = _opt_str(description)
|
||||
if objective is not _UNSET:
|
||||
t.objective = _opt_str(objective)
|
||||
if procedure_md is not _UNSET:
|
||||
t.procedure_md = procedure_md or None
|
||||
if prerequisites_md is not _UNSET:
|
||||
t.prerequisites_md = prerequisites_md or None
|
||||
if expected_result_red_md is not _UNSET:
|
||||
t.expected_result_red_md = expected_result_red_md or None
|
||||
if expected_detection_blue_md is not _UNSET:
|
||||
t.expected_detection_blue_md = expected_detection_blue_md or None
|
||||
if opsec_level is not None:
|
||||
_validate_opsec(opsec_level)
|
||||
t.opsec_level = opsec_level
|
||||
if tags is not _UNSET:
|
||||
t.tags = _normalize_string_list(tags)
|
||||
if expected_iocs is not _UNSET:
|
||||
t.expected_iocs = _normalize_string_list(expected_iocs)
|
||||
if mitre_tags is not _UNSET:
|
||||
for row in list(t.mitre_tags):
|
||||
s.delete(row)
|
||||
s.flush()
|
||||
rows = _resolve_mitre_refs(s, list(mitre_tags or []))
|
||||
for row in rows:
|
||||
row.test_template_id = t.id
|
||||
s.add(row)
|
||||
s.flush()
|
||||
s.refresh(t)
|
||||
return _to_view(s, t)
|
||||
|
||||
|
||||
def soft_delete_test_template(template_id: uuid.UUID) -> None:
|
||||
with session_scope() as s:
|
||||
t = s.get(TestTemplate, template_id)
|
||||
if t is None or t.deleted_at is not None:
|
||||
raise TestTemplateNotFound()
|
||||
t.deleted_at = datetime.now(tz=timezone.utc)
|
||||
Reference in New Issue
Block a user