"""CRUD service for `test_templates` + their MITRE tags.

The MITRE tag set is **fully replaced** on every update — partial mutation of
the join rows would force the API client to track tag UUIDs they never
created.

The polymorphic join (one of `tactic_id` / `technique_id` / `subtechnique_id`
populated) is owned here: callers pass `(kind, external_id)` tuples and we
resolve them to the matching MITRE row.
"""

from __future__ import annotations

import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Iterable

from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session, selectinload

from app.db.session import session_scope
from app.db.types import MITRE_KINDS, OPSEC_LEVELS
from app.models.mitre import MitreSubtechnique, MitreTactic, MitreTechnique
from app.models.template import TestTemplate, TestTemplateMitreTag

# Sentinel meaning "argument not supplied" for `update_test_template`, so an
# explicit `None` can still be passed to clear a field.
_UNSET: Any = object()

# Escape character for LIKE patterns built from user search input.
_LIKE_ESCAPE = "/"

# Upper bound on the page size accepted by `list_test_templates`.
_MAX_PAGE_SIZE = 500


class TestTemplateNotFound(Exception):
    """Raised when a template id is unknown, or is soft-deleted where that is not allowed."""


class UnknownMitreTag(Exception):
    """Raised when a (kind, external_id) tuple doesn't resolve to a known MITRE row."""


@dataclass(frozen=True)
class MitreTagRef:
    """Inbound MITRE tag reference.

    `external_id` is the ATT&CK identifier (TA…/T…/T….…) — we resolve it
    server-side, the client never sees UUIDs.
    """

    kind: str  # "tactic" | "technique" | "subtechnique"
    external_id: str


@dataclass(frozen=True)
class MitreTagView:
    """Outbound, resolved MITRE tag as rendered inside a TestTemplateView."""

    kind: str
    external_id: str
    name: str
    url: str | None


@dataclass(frozen=True)
class TestTemplateView:
    """Read model returned by the list/get/create/update functions of this module."""

    id: uuid.UUID
    name: str
    description: str | None
    objective: str | None
    procedure_md: str | None
    prerequisites_md: str | None
    expected_result_red_md: str | None
    expected_detection_blue_md: str | None
    opsec_level: str
    tags: list[str]
    expected_iocs: list[str]
    mitre_tags: list[MitreTagView]  # sorted by (kind, external_id)
    deleted_at: datetime | None
    created_at: datetime
    updated_at: datetime


def _validate_opsec(value: str) -> str:
    """Return `value` unchanged if it is one of OPSEC_LEVELS.

    Raises:
        ValueError: if `value` is not an allowed opsec level.
    """
    if value not in OPSEC_LEVELS:
        raise ValueError(f"opsec_level must be one of {OPSEC_LEVELS}")
    return value


def _normalize_string_list(values: Iterable[str] | None) -> list[str]:
    """Strip items, drop blanks and de-duplicate, preserving first-seen order.

    `None` or an empty iterable yields `[]`.

    Raises:
        ValueError: if `values` is a bare string (which would otherwise be
            iterated character by character) or contains a non-string item.
    """
    if not values:
        return []
    if isinstance(values, str):
        raise ValueError("expected a list of strings, not a single string")
    seen: set[str] = set()
    out: list[str] = []
    for raw in values:
        if not isinstance(raw, str):
            raise ValueError("list items must be strings")
        v = raw.strip()
        if not v or v in seen:
            continue
        seen.add(v)
        out.append(v)
    return out


def _external_id_map(
    s: Session,
    model: type[MitreTactic] | type[MitreTechnique] | type[MitreSubtechnique],
    external_ids: set[str],
) -> dict[str, uuid.UUID]:
    """Map ATT&CK external ids to primary keys for one MITRE table.

    Issues no query when `external_ids` is empty. Ids without a matching row
    are simply absent from the result.
    """
    if not external_ids:
        return {}
    stmt = select(model).where(model.external_id.in_(external_ids))
    return {row.external_id: row.id for row in s.scalars(stmt).all()}


def _resolve_mitre_refs(s: Session, refs: list[MitreTagRef]) -> list[TestTemplateMitreTag]:
    """Translate `(kind, external_id)` pairs into half-populated join rows.

    The returned rows carry `mitre_kind` and the matching FK column but no
    `test_template_id`; the caller assigns it before adding them to the session.

    Validates that:
    - `kind` is one of the supported values
    - each external_id resolves to an existing MITRE row
    - the combination is unique inside the payload (de-duped silently — same
      tag twice is a no-op, not an error)

    Raises:
        ValueError: on an unsupported kind or an empty external_id.
        UnknownMitreTag: listing every ref that did not resolve.
    """
    if not refs:
        return []

    # Dedupe input; dict keys keep first-seen order.
    keys: dict[tuple[str, str], None] = {}
    for ref in refs:
        if ref.kind not in MITRE_KINDS:
            raise ValueError(f"mitre tag kind must be one of {MITRE_KINDS}")
        if not ref.external_id:
            raise ValueError("mitre tag external_id is required")
        keys[(ref.kind, ref.external_id)] = None

    def _ids_of(kind: str) -> set[str]:
        return {external_id for k, external_id in keys if k == kind}

    # At most one query per kind, and none for kinds absent from the payload.
    tactic_map = _external_id_map(s, MitreTactic, _ids_of("tactic"))
    technique_map = _external_id_map(s, MitreTechnique, _ids_of("technique"))
    subtechnique_map = _external_id_map(s, MitreSubtechnique, _ids_of("subtechnique"))

    rows: list[TestTemplateMitreTag] = []
    missing: list[tuple[str, str]] = []
    for kind, external_id in keys:
        row: TestTemplateMitreTag | None = None
        if kind == "tactic":
            mid = tactic_map.get(external_id)
            if mid is not None:
                row = TestTemplateMitreTag(mitre_kind="tactic", tactic_id=mid)
        elif kind == "technique":
            mid = technique_map.get(external_id)
            if mid is not None:
                row = TestTemplateMitreTag(mitre_kind="technique", technique_id=mid)
        elif kind == "subtechnique":
            mid = subtechnique_map.get(external_id)
            if mid is not None:
                row = TestTemplateMitreTag(mitre_kind="subtechnique", subtechnique_id=mid)
        # Any other kind accepted by MITRE_KINDS has no backing table here, so
        # it falls through as unresolved.
        if row is None:
            missing.append((kind, external_id))
        else:
            rows.append(row)

    if missing:
        raise UnknownMitreTag(f"unknown MITRE tags: {sorted(missing)}")
    return rows


def _rows_by_id(
    s: Session,
    model: type[MitreTactic] | type[MitreTechnique] | type[MitreSubtechnique],
    ids: set[uuid.UUID],
) -> dict[uuid.UUID, Any]:
    """Load rows of one MITRE table keyed by primary key; no query when `ids` is empty."""
    if not ids:
        return {}
    stmt = select(model).where(model.id.in_(ids))
    return {row.id: row for row in s.scalars(stmt).all()}


@dataclass(frozen=True)
class _MitreRowMaps:
    """MITRE rows keyed by primary key, one map per kind."""

    tactics: dict[uuid.UUID, MitreTactic]
    techniques: dict[uuid.UUID, MitreTechnique]
    subtechniques: dict[uuid.UUID, MitreSubtechnique]

    def views_for(self, tags: Iterable[TestTemplateMitreTag]) -> list[MitreTagView]:
        """Render `tags` as MitreTagViews sorted by (kind, external_id).

        Tags whose MITRE row is not in the maps, or whose kind is not one of
        the three known kinds, are skipped rather than raising.
        """
        views: list[MitreTagView] = []
        for tag in tags:
            if tag.mitre_kind == "tactic":
                kind, row = "tactic", self.tactics.get(tag.tactic_id)
            elif tag.mitre_kind == "technique":
                kind, row = "technique", self.techniques.get(tag.technique_id)
            elif tag.mitre_kind == "subtechnique":
                kind, row = "subtechnique", self.subtechniques.get(tag.subtechnique_id)
            else:
                continue
            if row is None:
                continue
            views.append(MitreTagView(kind=kind, external_id=row.external_id, name=row.name, url=row.url))
        views.sort(key=lambda v: (v.kind, v.external_id))
        return views


def _load_mitre_rows(s: Session, tags: Iterable[TestTemplateMitreTag]) -> _MitreRowMaps:
    """Bulk-load every MITRE row referenced by `tags` — at most one query per kind."""
    tactic_ids: set[uuid.UUID] = set()
    technique_ids: set[uuid.UUID] = set()
    sub_ids: set[uuid.UUID] = set()
    for tag in tags:
        if tag.mitre_kind == "tactic" and tag.tactic_id is not None:
            tactic_ids.add(tag.tactic_id)
        elif tag.mitre_kind == "technique" and tag.technique_id is not None:
            technique_ids.add(tag.technique_id)
        elif tag.mitre_kind == "subtechnique" and tag.subtechnique_id is not None:
            sub_ids.add(tag.subtechnique_id)
    return _MitreRowMaps(
        tactics=_rows_by_id(s, MitreTactic, tactic_ids),
        techniques=_rows_by_id(s, MitreTechnique, technique_ids),
        subtechniques=_rows_by_id(s, MitreSubtechnique, sub_ids),
    )


def _resolve_mitre_views(s: Session, tags: list[TestTemplateMitreTag]) -> list[MitreTagView]:
    """Batch-resolve polymorphic MITRE FKs into sorted MitreTagViews.

    Uses at most 3 queries (one per kind present), regardless of how many
    tags are passed.
    """
    return _load_mitre_rows(s, tags).views_for(tags)


def _template_view(t: TestTemplate, mitre_views: list[MitreTagView]) -> TestTemplateView:
    """Copy a TestTemplate row into its read model, using pre-resolved MITRE views."""
    return TestTemplateView(
        id=t.id,
        name=t.name,
        description=t.description,
        objective=t.objective,
        procedure_md=t.procedure_md,
        prerequisites_md=t.prerequisites_md,
        expected_result_red_md=t.expected_result_red_md,
        expected_detection_blue_md=t.expected_detection_blue_md,
        opsec_level=t.opsec_level,
        tags=list(t.tags or []),
        expected_iocs=list(t.expected_iocs or []),
        mitre_tags=mitre_views,
        deleted_at=t.deleted_at,
        created_at=t.created_at,
        updated_at=t.updated_at,
    )


def _to_views_batch(s: Session, templates: list[TestTemplate]) -> list[TestTemplateView]:
    """List-level batcher: one bulk MITRE resolve for all templates' tags.

    For a list of K templates with ~T tags each, this issues at most 3 queries
    (one per MITRE kind) instead of 3K, then assembles each view in memory.
    """
    maps = _load_mitre_rows(s, (tag for t in templates for tag in t.mitre_tags))
    return [_template_view(t, maps.views_for(t.mitre_tags)) for t in templates]


def _to_view(s: Session, t: TestTemplate) -> TestTemplateView:
    """Build the read model for a single template."""
    return _template_view(t, _resolve_mitre_views(s, list(t.mitre_tags)))


def _base_query():
    """SELECT over templates with their MITRE join rows eager-loaded via selectinload."""
    return select(TestTemplate).options(selectinload(TestTemplate.mitre_tags))


def _escape_like(term: str) -> str:
    """Escape LIKE metacharacters so `term` matches literally (pair with `escape=_LIKE_ESCAPE`)."""
    return (
        term.replace(_LIKE_ESCAPE, _LIKE_ESCAPE + _LIKE_ESCAPE)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )


def _facet_subquery(column: Any, mitre_id: uuid.UUID):
    """SELECT of template ids having a MITRE join row whose `column` equals `mitre_id`."""
    return select(TestTemplateMitreTag.test_template_id).where(column == mitre_id).distinct()


def list_test_templates(
    *,
    q: str | None = None,
    tactic: str | None = None,  # external_id like "TA0006"
    technique: str | None = None,
    subtechnique: str | None = None,
    opsec_level: str | None = None,
    tag: str | None = None,
    include_deleted: bool = False,
    limit: int = 100,
    offset: int = 0,
) -> tuple[list[TestTemplateView], int]:
    """Return one page of templates plus the total number of matches.

    Args:
        q: case-insensitive substring matched (literally) against name or description.
        tactic / technique / subtechnique: ATT&CK external ids; each one given
            is AND-combined with the others.
        opsec_level: exact opsec level filter.
        tag: keep templates whose `tags` contain this value.
        include_deleted: also return soft-deleted templates.
        limit: page size, clamped to [1, 500].
        offset: rows to skip, clamped to >= 0.

    Returns:
        (views ordered by name then id, total count ignoring limit/offset).
        An unknown MITRE external id short-circuits to `([], 0)`.

    Raises:
        ValueError: if `opsec_level` is not an allowed value.
    """
    with session_scope() as s:
        conditions: list[Any] = []
        if not include_deleted:
            conditions.append(TestTemplate.deleted_at.is_(None))

        term = q.strip() if q else ""
        if term:
            # Escape user input so '%' / '_' in the search text are not wildcards.
            like = f"%{_escape_like(term.lower())}%"
            conditions.append(
                or_(
                    func.lower(TestTemplate.name).like(like, escape=_LIKE_ESCAPE),
                    func.lower(TestTemplate.description).like(like, escape=_LIKE_ESCAPE),
                )
            )
        if opsec_level:
            _validate_opsec(opsec_level)
            conditions.append(TestTemplate.opsec_level == opsec_level)
        if tag:
            conditions.append(TestTemplate.tags.any(tag))

        # MITRE facets: each provided facet (tactic, technique, subtechnique) is
        # AND-combined — a template tagged BOTH `TA0006` AND `T1003` matches a
        # query with `?tactic=TA0006&technique=T1003`, but a template tagged
        # only `TA0006` does NOT. Each facet matches strictly its own column
        # (no cross-column UUID collision risk).
        facets = (
            (tactic, MitreTactic, TestTemplateMitreTag.tactic_id),
            (technique, MitreTechnique, TestTemplateMitreTag.technique_id),
            (subtechnique, MitreSubtechnique, TestTemplateMitreTag.subtechnique_id),
        )
        for external_id, model, column in facets:
            if not external_id:
                continue
            mitre_row = s.scalar(select(model).where(model.external_id == external_id))
            if mitre_row is None:
                return [], 0
            conditions.append(TestTemplate.id.in_(_facet_subquery(column, mitre_row.id)))

        count_stmt = select(func.count()).select_from(TestTemplate).where(*conditions)
        total = s.scalar(count_stmt) or 0

        page_stmt = (
            _base_query()
            .where(*conditions)
            # `name` is not unique: `id` breaks ties so LIMIT/OFFSET paging is stable.
            .order_by(TestTemplate.name.asc(), TestTemplate.id.asc())
            .limit(max(1, min(limit, _MAX_PAGE_SIZE)))
            .offset(max(0, offset))
        )
        rows = s.scalars(page_stmt).all()
        return _to_views_batch(s, list(rows)), int(total)


def get_test_template(template_id: uuid.UUID, *, include_deleted: bool = False) -> TestTemplateView:
    """Fetch one template by id.

    Raises:
        TestTemplateNotFound: if the id is unknown, or the template is
            soft-deleted and `include_deleted` is False.
    """
    with session_scope() as s:
        t = s.get(TestTemplate, template_id)
        if t is None or (t.deleted_at is not None and not include_deleted):
            raise TestTemplateNotFound()
        return _to_view(s, t)


def _opt_str(value: str | None) -> str | None:
    """Strip `value`; map None and blank strings to None."""
    if value is None:
        return None
    return value.strip() or None


def create_test_template(
    *,
    name: str,
    description: str | None = None,
    objective: str | None = None,
    procedure_md: str | None = None,
    prerequisites_md: str | None = None,
    expected_result_red_md: str | None = None,
    expected_detection_blue_md: str | None = None,
    opsec_level: str = "medium",
    tags: list[str] | None = None,
    expected_iocs: list[str] | None = None,
    mitre_tags: list[MitreTagRef] | None = None,
) -> TestTemplateView:
    """Create a template (and its MITRE tags) and return its view.

    Raises:
        ValueError: on a blank name, invalid opsec level, malformed string
            lists, or malformed MITRE refs.
        UnknownMitreTag: if a MITRE ref does not resolve.
    """
    name_norm = (name or "").strip()
    if not name_norm:
        raise ValueError("name is required")
    _validate_opsec(opsec_level)
    norm_tags = _normalize_string_list(tags)
    norm_iocs = _normalize_string_list(expected_iocs)

    with session_scope() as s:
        # Resolve MITRE refs before writing anything, so a bad tag fails
        # without an INSERT of the template itself.
        tag_rows = _resolve_mitre_refs(s, list(mitre_tags or []))

        t = TestTemplate(
            name=name_norm,
            description=_opt_str(description),
            objective=_opt_str(objective),
            # Markdown fields are kept verbatim (not stripped); "" becomes None.
            procedure_md=procedure_md or None,
            prerequisites_md=prerequisites_md or None,
            expected_result_red_md=expected_result_red_md or None,
            expected_detection_blue_md=expected_detection_blue_md or None,
            opsec_level=opsec_level,
            tags=norm_tags,
            expected_iocs=norm_iocs,
        )
        s.add(t)
        s.flush()  # populate t.id for the join rows below
        for row in tag_rows:
            row.test_template_id = t.id
            s.add(row)
        s.flush()
        s.refresh(t)
        return _to_view(s, t)


def update_test_template(
    template_id: uuid.UUID,
    *,
    name: str | None = None,
    description: Any = _UNSET,
    objective: Any = _UNSET,
    procedure_md: Any = _UNSET,
    prerequisites_md: Any = _UNSET,
    expected_result_red_md: Any = _UNSET,
    expected_detection_blue_md: Any = _UNSET,
    opsec_level: str | None = None,
    tags: Any = _UNSET,
    expected_iocs: Any = _UNSET,
    mitre_tags: Any = _UNSET,
) -> TestTemplateView:
    """Partially update a live template; omitted (`_UNSET`) fields are left alone.

    Passing `None` for an `_UNSET`-defaulted field clears it; for `mitre_tags`
    the whole tag set is replaced (None / [] removes every tag).

    Raises:
        TestTemplateNotFound: if the id is unknown or soft-deleted.
        ValueError: on a blank name, invalid opsec level, malformed string
            lists, or malformed MITRE refs.
        UnknownMitreTag: if a MITRE ref does not resolve.
    """
    with session_scope() as s:
        t = s.get(TestTemplate, template_id)
        if t is None or t.deleted_at is not None:
            raise TestTemplateNotFound()

        # Validate/normalize every input before mutating anything, so a
        # rejected update never leaves the existing tag rows deleted.
        new_name: str | None = None
        if name is not None:
            new_name = name.strip()
            if not new_name:
                raise ValueError("name cannot be empty")
        if opsec_level is not None:
            _validate_opsec(opsec_level)
        new_tags = _normalize_string_list(tags) if tags is not _UNSET else None
        new_iocs = _normalize_string_list(expected_iocs) if expected_iocs is not _UNSET else None
        new_tag_rows = _resolve_mitre_refs(s, list(mitre_tags or [])) if mitre_tags is not _UNSET else None

        if new_name is not None:
            t.name = new_name
        if description is not _UNSET:
            t.description = _opt_str(description)
        if objective is not _UNSET:
            t.objective = _opt_str(objective)
        if procedure_md is not _UNSET:
            t.procedure_md = procedure_md or None
        if prerequisites_md is not _UNSET:
            t.prerequisites_md = prerequisites_md or None
        if expected_result_red_md is not _UNSET:
            t.expected_result_red_md = expected_result_red_md or None
        if expected_detection_blue_md is not _UNSET:
            t.expected_detection_blue_md = expected_detection_blue_md or None
        if opsec_level is not None:
            t.opsec_level = opsec_level
        if new_tags is not None:
            t.tags = new_tags
        if new_iocs is not None:
            t.expected_iocs = new_iocs

        if new_tag_rows is not None:
            # Full replacement. The DELETEs are flushed on their own first so
            # they reach the DB before the new INSERTs (NOTE(review): presumably
            # to avoid tripping a uniqueness constraint when a tag is re-added).
            for row in list(t.mitre_tags):
                s.delete(row)
            s.flush()
            for row in new_tag_rows:
                row.test_template_id = t.id
                s.add(row)

        s.flush()
        s.refresh(t)
        return _to_view(s, t)


def soft_delete_test_template(template_id: uuid.UUID) -> None:
    """Mark a live template as deleted by stamping `deleted_at` with the current UTC time.

    Raises:
        TestTemplateNotFound: if the id is unknown or already soft-deleted.
    """
    with session_scope() as s:
        t = s.get(TestTemplate, template_id)
        if t is None or t.deleted_at is not None:
            raise TestTemplateNotFound()
        t.deleted_at = datetime.now(tz=timezone.utc)