"""CRUD service for `test_templates` + their MITRE tags. The MITRE tag set is **fully replaced** on every update — partial mutation of the join rows would force the API client to track tag UUIDs they never created. The polymorphic join (one of `tactic_id` / `technique_id` / `subtechnique_id` populated) is owned here: callers pass `(kind, external_id)` tuples and we resolve them to the matching MITRE row. """ from __future__ import annotations import uuid from dataclasses import dataclass from datetime import datetime, timezone from typing import Any, Iterable from sqlalchemy import func, or_, select from sqlalchemy.orm import Session, selectinload _UNSET: Any = object() from app.db.session import session_scope from app.db.types import MITRE_KINDS, OPSEC_LEVELS from app.models.mitre import MitreSubtechnique, MitreTactic, MitreTechnique from app.models.template import TestTemplate, TestTemplateMitreTag class TestTemplateNotFound(Exception): pass class UnknownMitreTag(Exception): """Raised when an (kind, external_id) tuple doesn't resolve to a known MITRE row.""" @dataclass(frozen=True) class MitreTagRef: """Inbound MITRE tag reference. `external_id` is the ATT&CK identifier (TA…/T…/T….…) — we resolve it server-side, the client never sees UUIDs. """ kind: str # "tactic" | "technique" | "subtechnique" external_id: str @dataclass(frozen=True) class MitreTagView: kind: str external_id: str name: str url: str | None @dataclass(frozen=True) class TestTemplateView: id: uuid.UUID name: str description: str | None objective: str | None procedure_md: str | None prerequisites_md: str | None expected_result_red_md: str | None expected_detection_blue_md: str | None opsec_level: str tags: list[str] expected_iocs: list[str] mitre_tags: list[MitreTagView] deleted_at: datetime | None created_at: datetime updated_at: datetime def _validate_opsec(value: str) -> str: if value not in OPSEC_LEVELS: raise ValueError(f"opsec_level must be one of {OPSEC_LEVELS}") return value def _normalize_string_list(values: Iterable[str] | None) -> list[str]: if not values: return [] seen: set[str] = set() out: list[str] = [] for raw in values: if not isinstance(raw, str): raise ValueError("list items must be strings") v = raw.strip() if not v or v in seen: continue seen.add(v) out.append(v) return out def _resolve_mitre_refs(s: Session, refs: list[MitreTagRef]) -> list[TestTemplateMitreTag]: """Translate `(kind, external_id)` pairs into half-populated join rows. Validates that: - `kind` is one of the supported values - each external_id resolves to an existing MITRE row - the combination is unique inside the payload (de-duped silently — same tag twice is a no-op, not an error) """ if not refs: return [] # Dedupe input deduped: dict[tuple[str, str], MitreTagRef] = {} for ref in refs: if ref.kind not in MITRE_KINDS: raise ValueError(f"mitre tag kind must be one of {MITRE_KINDS}") if not ref.external_id: raise ValueError("mitre tag external_id is required") deduped[(ref.kind, ref.external_id)] = ref tactic_ids = {r.external_id for r in deduped.values() if r.kind == "tactic"} technique_ids = {r.external_id for r in deduped.values() if r.kind == "technique"} subtechnique_ids = {r.external_id for r in deduped.values() if r.kind == "subtechnique"} tactic_map = { t.external_id: t.id for t in s.scalars(select(MitreTactic).where(MitreTactic.external_id.in_(tactic_ids))).all() } technique_map = { t.external_id: t.id for t in s.scalars(select(MitreTechnique).where(MitreTechnique.external_id.in_(technique_ids))).all() } subtechnique_map = { sb.external_id: sb.id for sb in s.scalars( select(MitreSubtechnique).where(MitreSubtechnique.external_id.in_(subtechnique_ids)) ).all() } rows: list[TestTemplateMitreTag] = [] missing: list[tuple[str, str]] = [] for ref in deduped.values(): if ref.kind == "tactic": mid = tactic_map.get(ref.external_id) if mid is None: missing.append((ref.kind, ref.external_id)) continue rows.append(TestTemplateMitreTag(mitre_kind="tactic", tactic_id=mid)) elif ref.kind == "technique": mid = technique_map.get(ref.external_id) if mid is None: missing.append((ref.kind, ref.external_id)) continue rows.append(TestTemplateMitreTag(mitre_kind="technique", technique_id=mid)) else: mid = subtechnique_map.get(ref.external_id) if mid is None: missing.append((ref.kind, ref.external_id)) continue rows.append(TestTemplateMitreTag(mitre_kind="subtechnique", subtechnique_id=mid)) if missing: raise UnknownMitreTag(f"unknown MITRE tags: {sorted(missing)}") return rows def _to_view(s: Session, t: TestTemplate) -> TestTemplateView: tag_views: list[MitreTagView] = [] for tag in t.mitre_tags: if tag.mitre_kind == "tactic" and tag.tactic_id is not None: row = s.get(MitreTactic, tag.tactic_id) if row is not None: tag_views.append(MitreTagView(kind="tactic", external_id=row.external_id, name=row.name, url=row.url)) elif tag.mitre_kind == "technique" and tag.technique_id is not None: row = s.get(MitreTechnique, tag.technique_id) if row is not None: tag_views.append(MitreTagView(kind="technique", external_id=row.external_id, name=row.name, url=row.url)) elif tag.mitre_kind == "subtechnique" and tag.subtechnique_id is not None: row = s.get(MitreSubtechnique, tag.subtechnique_id) if row is not None: tag_views.append( MitreTagView(kind="subtechnique", external_id=row.external_id, name=row.name, url=row.url) ) tag_views.sort(key=lambda v: (v.kind, v.external_id)) return TestTemplateView( id=t.id, name=t.name, description=t.description, objective=t.objective, procedure_md=t.procedure_md, prerequisites_md=t.prerequisites_md, expected_result_red_md=t.expected_result_red_md, expected_detection_blue_md=t.expected_detection_blue_md, opsec_level=t.opsec_level, tags=list(t.tags or []), expected_iocs=list(t.expected_iocs or []), mitre_tags=tag_views, deleted_at=t.deleted_at, created_at=t.created_at, updated_at=t.updated_at, ) def _base_query(): return select(TestTemplate).options(selectinload(TestTemplate.mitre_tags)) def list_test_templates( *, q: str | None = None, tactic: str | None = None, # external_id like "TA0006" technique: str | None = None, subtechnique: str | None = None, opsec_level: str | None = None, tag: str | None = None, include_deleted: bool = False, limit: int = 100, offset: int = 0, ) -> tuple[list[TestTemplateView], int]: with session_scope() as s: stmt = _base_query().order_by(TestTemplate.name.asc()) count_stmt = select(func.count()).select_from(TestTemplate) if not include_deleted: stmt = stmt.where(TestTemplate.deleted_at.is_(None)) count_stmt = count_stmt.where(TestTemplate.deleted_at.is_(None)) if q: like = f"%{q.lower()}%" cond = or_( func.lower(TestTemplate.name).like(like), func.lower(TestTemplate.description).like(like), ) stmt = stmt.where(cond) count_stmt = count_stmt.where(cond) if opsec_level: _validate_opsec(opsec_level) stmt = stmt.where(TestTemplate.opsec_level == opsec_level) count_stmt = count_stmt.where(TestTemplate.opsec_level == opsec_level) if tag: stmt = stmt.where(TestTemplate.tags.any(tag)) count_stmt = count_stmt.where(TestTemplate.tags.any(tag)) # MITRE facet: resolve external_id → uuid then filter via join subquery. if tactic or technique or subtechnique: tag_ids: list[uuid.UUID] = [] if tactic: tac = s.scalar(select(MitreTactic).where(MitreTactic.external_id == tactic)) if tac is None: return [], 0 tag_ids.append(tac.id) if technique: tech = s.scalar(select(MitreTechnique).where(MitreTechnique.external_id == technique)) if tech is None: return [], 0 tag_ids.append(tech.id) if subtechnique: sub = s.scalar(select(MitreSubtechnique).where(MitreSubtechnique.external_id == subtechnique)) if sub is None: return [], 0 tag_ids.append(sub.id) sub_q = ( select(TestTemplateMitreTag.test_template_id) .where( or_( TestTemplateMitreTag.tactic_id.in_(tag_ids), TestTemplateMitreTag.technique_id.in_(tag_ids), TestTemplateMitreTag.subtechnique_id.in_(tag_ids), ) ) .distinct() ) stmt = stmt.where(TestTemplate.id.in_(sub_q)) count_stmt = count_stmt.where(TestTemplate.id.in_(sub_q)) total = s.scalar(count_stmt) or 0 rows = s.scalars(stmt.limit(max(1, min(limit, 500))).offset(max(0, offset))).all() return [_to_view(s, t) for t in rows], int(total) def get_test_template(template_id: uuid.UUID, *, include_deleted: bool = False) -> TestTemplateView: with session_scope() as s: t = s.get(TestTemplate, template_id) if t is None: raise TestTemplateNotFound() if t.deleted_at is not None and not include_deleted: raise TestTemplateNotFound() return _to_view(s, t) def create_test_template( *, name: str, description: str | None = None, objective: str | None = None, procedure_md: str | None = None, prerequisites_md: str | None = None, expected_result_red_md: str | None = None, expected_detection_blue_md: str | None = None, opsec_level: str = "medium", tags: list[str] | None = None, expected_iocs: list[str] | None = None, mitre_tags: list[MitreTagRef] | None = None, ) -> TestTemplateView: name_norm = (name or "").strip() if not name_norm: raise ValueError("name is required") _validate_opsec(opsec_level) norm_tags = _normalize_string_list(tags) norm_iocs = _normalize_string_list(expected_iocs) with session_scope() as s: t = TestTemplate( name=name_norm, description=_opt_str(description), objective=_opt_str(objective), procedure_md=procedure_md or None, prerequisites_md=prerequisites_md or None, expected_result_red_md=expected_result_red_md or None, expected_detection_blue_md=expected_detection_blue_md or None, opsec_level=opsec_level, tags=norm_tags, expected_iocs=norm_iocs, ) s.add(t) s.flush() if mitre_tags: rows = _resolve_mitre_refs(s, mitre_tags) for row in rows: row.test_template_id = t.id s.add(row) s.flush() s.refresh(t) return _to_view(s, t) def _opt_str(value: str | None) -> str | None: if value is None: return None s = value.strip() return s or None def update_test_template( template_id: uuid.UUID, *, name: str | None = None, description: Any = _UNSET, objective: Any = _UNSET, procedure_md: Any = _UNSET, prerequisites_md: Any = _UNSET, expected_result_red_md: Any = _UNSET, expected_detection_blue_md: Any = _UNSET, opsec_level: str | None = None, tags: Any = _UNSET, expected_iocs: Any = _UNSET, mitre_tags: Any = _UNSET, ) -> TestTemplateView: with session_scope() as s: t = s.get(TestTemplate, template_id) if t is None or t.deleted_at is not None: raise TestTemplateNotFound() if name is not None: n = name.strip() if not n: raise ValueError("name cannot be empty") t.name = n if description is not _UNSET: t.description = _opt_str(description) if objective is not _UNSET: t.objective = _opt_str(objective) if procedure_md is not _UNSET: t.procedure_md = procedure_md or None if prerequisites_md is not _UNSET: t.prerequisites_md = prerequisites_md or None if expected_result_red_md is not _UNSET: t.expected_result_red_md = expected_result_red_md or None if expected_detection_blue_md is not _UNSET: t.expected_detection_blue_md = expected_detection_blue_md or None if opsec_level is not None: _validate_opsec(opsec_level) t.opsec_level = opsec_level if tags is not _UNSET: t.tags = _normalize_string_list(tags) if expected_iocs is not _UNSET: t.expected_iocs = _normalize_string_list(expected_iocs) if mitre_tags is not _UNSET: for row in list(t.mitre_tags): s.delete(row) s.flush() rows = _resolve_mitre_refs(s, list(mitre_tags or [])) for row in rows: row.test_template_id = t.id s.add(row) s.flush() s.refresh(t) return _to_view(s, t) def soft_delete_test_template(template_id: uuid.UUID) -> None: with session_scope() as s: t = s.get(TestTemplate, template_id) if t is None or t.deleted_at is not None: raise TestTemplateNotFound() t.deleted_at = datetime.now(tz=timezone.utc)