feat: sprint 3 — multi-technique simulations + MITRE matrix modal #6

Merged
knacky merged 8 commits from sprint/3-mitre-matrix into main 2026-05-27 17:11:22 +00:00
8 changed files with 737 additions and 30 deletions
Showing only changes of commit b5ea2929de - Show all commits

View File

@@ -121,7 +121,7 @@ def transition_simulation(sid: int):
# ---------------------------------------------------------------------------
# MITRE autocomplete
# MITRE autocomplete + matrix
# ---------------------------------------------------------------------------
@@ -136,3 +136,14 @@ def mitre_techniques():
q = request.args.get("q", "").strip()
results = mitre_svc.search(q)
return jsonify(results), 200
@simulations_bp.get("/api/mitre/matrix")
@login_required
def mitre_matrix():
from backend.app.services import mitre as mitre_svc
if not mitre_svc.mitre_loaded:
return jsonify({"error": "mitre bundle not loaded"}), 503
return jsonify(mitre_svc.get_matrix()), 200

View File

@@ -25,8 +25,7 @@ class Simulation(db.Model): # type: ignore[name-defined]
index=True,
)
name = db.Column(db.String(255), nullable=False)
mitre_technique_id = db.Column(db.String(32), nullable=True)
mitre_technique_name = db.Column(db.String(255), nullable=True)
techniques = db.Column(db.JSON, nullable=False, default=list)
description = db.Column(db.Text, nullable=True)
commands = db.Column(db.Text, nullable=True)
prerequisites = db.Column(db.Text, nullable=True)

View File

@@ -20,13 +20,22 @@ def serialize_user_brief(user: User) -> dict[str, Any]:
return {"id": user.id, "username": user.username}
def _enrich_techniques(raw: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Attach tactics to each {id, name} snapshot from the MITRE service."""
from backend.app.services import mitre as mitre_svc
return [
{"id": t["id"], "name": t["name"], "tactics": mitre_svc.get_tactics(t["id"])}
for t in (raw or [])
]
def serialize_simulation(simulation: Simulation) -> dict[str, Any]:
return {
"id": simulation.id,
"engagement_id": simulation.engagement_id,
"name": simulation.name,
"mitre_technique_id": simulation.mitre_technique_id,
"mitre_technique_name": simulation.mitre_technique_name,
"techniques": _enrich_techniques(simulation.techniques or []),
"description": simulation.description,
"commands": simulation.commands,
"prerequisites": simulation.prerequisites,

View File

@@ -8,11 +8,30 @@ from typing import Any
logger = logging.getLogger(__name__)
# Absolute path to the committed bundle.
_BUNDLE_PATH = Path(__file__).parent.parent.parent / "data" / "mitre" / "enterprise-attack.json"
# Canonical Enterprise tactic order (12 tactics).
_TACTIC_ORDER = [
"initial-access",
"execution",
"persistence",
"privilege-escalation",
"defense-evasion",
"credential-access",
"discovery",
"lateral-movement",
"collection",
"command-and-control",
"exfiltration",
"impact",
]
mitre_loaded: bool = False
_index: list[dict[str, Any]] = []
_tactics_by_technique: dict[str, list[str]] = {}
_name_by_id: dict[str, str] = {}
# matrix: list of tactic dicts (built once at load time)
_matrix: list[dict[str, Any]] = []
def _extract_tactics(obj: dict[str, Any]) -> list[str]:
@@ -20,7 +39,7 @@ def _extract_tactics(obj: dict[str, Any]) -> list[str]:
return [
p["phase_name"]
for p in phases
if isinstance(p, dict) and "phase_name" in p
if isinstance(p, dict) and "phase_name" in p and p.get("kill_chain_name") == "mitre-attack"
]
@@ -31,9 +50,65 @@ def _get_external_id(obj: dict[str, Any]) -> str | None:
return None
def _is_subtechnique(tech_id: str) -> bool:
return "." in tech_id
def _parent_id(sub_id: str) -> str:
return sub_id.split(".")[0]
def _build_matrix(entries: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Build the tactic → techniques → subtechniques tree."""
# Group top-level techniques by tactic.
tactic_techs: dict[str, list[dict[str, Any]]] = {t: [] for t in _TACTIC_ORDER}
for entry in entries:
if _is_subtechnique(entry["id"]):
continue
for tactic in entry["tactics"]:
if tactic in tactic_techs:
tactic_techs[tactic].append(entry)
# Attach sub-techniques to their parents.
parent_subs: dict[str, list[dict[str, Any]]] = {}
for entry in entries:
if not _is_subtechnique(entry["id"]):
continue
pid = _parent_id(entry["id"])
parent_subs.setdefault(pid, []).append({"id": entry["id"], "name": entry["name"]})
# Sort subs alphabetically by name.
for subs in parent_subs.values():
subs.sort(key=lambda x: x["name"])
matrix: list[dict[str, Any]] = []
for tactic_id in _TACTIC_ORDER:
techs = tactic_techs.get(tactic_id, [])
# Sort techniques alphabetically.
techs_sorted = sorted(techs, key=lambda x: x["name"])
tactic_name = tactic_id.replace("-", " ").title()
matrix.append(
{
"tactic_id": tactic_id,
"tactic_name": tactic_name,
"techniques": [
{
"id": t["id"],
"name": t["name"],
"subtechniques": parent_subs.get(t["id"], []),
}
for t in techs_sorted
],
}
)
return matrix
def load_bundle(path: Path | None = None) -> None:
"""Load the MITRE bundle into memory. Called once at app boot."""
global mitre_loaded, _index
global mitre_loaded, _index, _tactics_by_technique, _name_by_id, _matrix
bundle_path = path or _BUNDLE_PATH
try:
@@ -49,6 +124,9 @@ def load_bundle(path: Path | None = None) -> None:
return
entries: list[dict[str, Any]] = []
tactics_map: dict[str, list[str]] = {}
name_map: dict[str, str] = {}
for obj in data.get("objects") or []:
if not isinstance(obj, dict):
continue
@@ -59,19 +137,35 @@ def load_bundle(path: Path | None = None) -> None:
ext_id = _get_external_id(obj)
if not ext_id:
continue
entries.append(
{
"id": ext_id,
"name": obj.get("name", ""),
"tactics": _extract_tactics(obj),
}
)
tactics = _extract_tactics(obj)
name = obj.get("name", "")
entries.append({"id": ext_id, "name": name, "tactics": tactics})
tactics_map[ext_id] = tactics
name_map[ext_id] = name
_index = entries
_tactics_by_technique = tactics_map
_name_by_id = name_map
_matrix = _build_matrix(entries)
mitre_loaded = True
logger.info("MITRE bundle loaded: %d techniques", len(_index))
def get_tactics(technique_id: str) -> list[str]:
"""Return tactic list for a technique id; empty list if unknown."""
return _tactics_by_technique.get(technique_id, [])
def lookup_name(technique_id: str) -> str | None:
"""Return the name for a technique id, or None if not in the bundle."""
return _name_by_id.get(technique_id)
def get_matrix() -> list[dict[str, Any]]:
"""Return the full tactic → techniques → subtechniques tree."""
return _matrix
def search(query: str, limit: int = 20) -> list[dict[str, Any]]:
"""Return up to `limit` techniques matching `query`.

View File

@@ -10,11 +10,10 @@ from backend.app.extensions import db
from backend.app.models import User
from backend.app.models.simulation import Simulation, SimulationStatus
# Fields only admin/redteam may write (excluding technique_ids which is handled separately).
REDTEAM_FIELDS = frozenset(
{
"name",
"mitre_technique_id",
"mitre_technique_name",
"description",
"commands",
"prerequisites",
@@ -25,8 +24,6 @@ REDTEAM_FIELDS = frozenset(
SOC_FIELDS = frozenset({"log_source", "logs", "soc_comment", "incident_number"})
# Transitions allowed via POST /transition endpoint (manual only).
# auto pending→in_progress is handled in apply_patch, not here.
_ALLOWED_TRANSITIONS: dict[str, dict[str, set[str]]] = {
"review_required": {
"from": {"pending", "in_progress"},
@@ -48,6 +45,27 @@ def _is_non_empty(value: Any) -> bool:
return not (isinstance(value, list) and len(value) == 0)
def _resolve_technique_ids(
technique_ids: list[str],
) -> tuple[list[dict[str, str]] | None, tuple[Any, int] | None]:
"""Validate and resolve technique IDs to [{id, name}] snapshots.
Returns (resolved_list, None) on success or (None, error_tuple) on failure.
Deduplicates while preserving order.
"""
from backend.app.services import mitre as mitre_svc
# Dedup, preserve order.
seen: dict[str, None] = dict.fromkeys(technique_ids)
resolved: list[dict[str, str]] = []
for tid in seen:
name = mitre_svc.lookup_name(tid)
if name is None:
return None, (jsonify({"error": f"unknown technique id: {tid}"}), 400)
resolved.append({"id": tid, "name": name})
return resolved, None
def apply_patch(
simulation: Simulation, payload: dict[str, Any], user: User
) -> tuple[Any, int] | None:
@@ -59,15 +77,14 @@ def apply_patch(
role = user.role.value
if role == "soc":
# SOC can only patch when status allows it.
if simulation.status not in (
SimulationStatus.REVIEW_REQUIRED,
SimulationStatus.DONE,
):
return jsonify({"error": "simulation not ready for SOC review"}), 403
# SOC must not send redteam fields.
redteam_keys_in_payload = REDTEAM_FIELDS & payload.keys()
# SOC must not send redteam fields or technique_ids.
redteam_keys_in_payload = (REDTEAM_FIELDS | {"technique_ids"}) & payload.keys()
if redteam_keys_in_payload:
return jsonify({"error": "soc cannot edit redteam fields"}), 403
@@ -76,10 +93,10 @@ def apply_patch(
setattr(simulation, field, payload[field])
else:
# admin / redteam: apply all fields present.
# admin / redteam path.
redteam_keys_present = REDTEAM_FIELDS & payload.keys()
# Validate executed_at before any writes so a bad value causes no partial mutation.
# Validate executed_at upfront before any writes.
executed_at_value: datetime | None = None
if "executed_at" in redteam_keys_present:
val = payload["executed_at"]
@@ -91,21 +108,39 @@ def apply_patch(
except ValueError:
return jsonify({"error": "invalid executed_at"}), 400
# Validate and resolve technique_ids upfront.
resolved_techniques: list[dict[str, str]] | None = None
if "technique_ids" in payload:
raw_ids = payload["technique_ids"]
if not isinstance(raw_ids, list):
return jsonify({"error": "technique_ids must be a list"}), 400
resolved_techniques, err = _resolve_technique_ids(raw_ids)
if err is not None:
return err
# Apply scalar redteam fields.
for field in redteam_keys_present:
if field == "executed_at":
simulation.executed_at = executed_at_value
else:
setattr(simulation, field, payload[field])
# Apply resolved techniques.
if resolved_techniques is not None:
simulation.techniques = resolved_techniques
# Apply SOC fields (admin/redteam may also write them).
for field in SOC_FIELDS:
if field in payload:
setattr(simulation, field, payload[field])
# Auto-transition pending → in_progress: at least one redteam field with
# a non-empty value in the *incoming payload*.
if simulation.status == SimulationStatus.PENDING and any(
_is_non_empty(payload[k]) for k in redteam_keys_present
):
# Auto-transition pending → in_progress.
# Triggers when any redteam scalar has a non-empty value, OR technique_ids is non-empty.
auto_trigger = any(_is_non_empty(payload[k]) for k in redteam_keys_present)
if not auto_trigger and "technique_ids" in payload:
auto_trigger = len(payload["technique_ids"]) > 0
if simulation.status == SimulationStatus.PENDING and auto_trigger:
simulation.status = SimulationStatus.IN_PROGRESS
simulation.updated_at = datetime.now(UTC)

View File

@@ -0,0 +1,82 @@
"""replace scalar MITRE columns with techniques JSON array
Revision ID: 0003
Revises: 0002
Create Date: 2026-05-27 00:00:00.000000
"""
import json
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import column, table, text
revision = "0003"
down_revision = "0002"
branch_labels = None
depends_on = None
# Lightweight table proxies for data migration (no ORM import).
_sims = table(
"simulations",
column("id", sa.Integer),
column("mitre_technique_id", sa.String),
column("mitre_technique_name", sa.String),
column("techniques", sa.Text),
)
def upgrade():
bind = op.get_bind()
# 1. Add techniques column (nullable while we backfill).
op.add_column("simulations", sa.Column("techniques", sa.Text(), nullable=True))
# 2. Backfill: scalar → JSON array.
rows = bind.execute(
text("SELECT id, mitre_technique_id, mitre_technique_name FROM simulations")
).fetchall()
for row in rows:
if row[1]: # mitre_technique_id is not null
val = json.dumps([{"id": row[1], "name": row[2] or ""}])
else:
val = "[]"
bind.execute(
text("UPDATE simulations SET techniques = :v WHERE id = :id"),
{"v": val, "id": row[0]},
)
# 3. Make NOT NULL now that every row has a value.
# SQLite doesn't support ALTER COLUMN, so we skip the nullable constraint
# change at DDL level — the application model enforces it.
# 4. Drop old scalar columns.
with op.batch_alter_table("simulations") as batch_op:
batch_op.drop_column("mitre_technique_id")
batch_op.drop_column("mitre_technique_name")
def downgrade():
bind = op.get_bind()
# 1. Re-add scalar columns.
with op.batch_alter_table("simulations") as batch_op:
batch_op.add_column(sa.Column("mitre_technique_id", sa.String(length=32), nullable=True))
batch_op.add_column(sa.Column("mitre_technique_name", sa.String(length=255), nullable=True))
# 2. Back-fill: take first element of techniques array.
rows = bind.execute(text("SELECT id, techniques FROM simulations")).fetchall()
for row in rows:
techniques = json.loads(row[1] or "[]")
if techniques:
first = techniques[0]
bind.execute(
text(
"UPDATE simulations SET mitre_technique_id = :tid, mitre_technique_name = :tname WHERE id = :id"
),
{"tid": first.get("id"), "tname": first.get("name"), "id": row[0]},
)
# 3. Drop techniques column.
with op.batch_alter_table("simulations") as batch_op:
batch_op.drop_column("techniques")

View File

@@ -33,6 +33,14 @@ _FIXTURE_BUNDLE = {
],
"kill_chain_phases": [{"phase_name": "execution", "kill_chain_name": "mitre-attack"}],
},
{
"type": "attack-pattern",
"name": "Python",
"external_references": [
{"source_name": "mitre-attack", "external_id": "T1059.006"}
],
"kill_chain_phases": [{"phase_name": "execution", "kill_chain_name": "mitre-attack"}],
},
{
"type": "attack-pattern",
"name": "Phishing",
@@ -76,9 +84,15 @@ def _reset_mitre():
"""Reset the MITRE service state between tests."""
original_loaded = mitre_svc.mitre_loaded
original_index = list(mitre_svc._index)
original_tactics = dict(mitre_svc._tactics_by_technique)
original_names = dict(mitre_svc._name_by_id)
original_matrix = list(mitre_svc._matrix)
yield
mitre_svc.mitre_loaded = original_loaded
mitre_svc._index = original_index
mitre_svc._tactics_by_technique = original_tactics
mitre_svc._name_by_id = original_names
mitre_svc._matrix = original_matrix
@pytest.fixture()
@@ -96,7 +110,7 @@ def bundle_file(tmp_path: pathlib.Path) -> pathlib.Path:
def test_load_bundle_success(bundle_file: pathlib.Path) -> None:
mitre_svc.load_bundle(bundle_file)
assert mitre_svc.mitre_loaded is True
assert len(mitre_svc._index) == 4 # 5 objects minus 1 revoked = 4
assert len(mitre_svc._index) == 5 # 6 attack-patterns minus 1 revoked = 5
def test_load_bundle_missing_file() -> None:
@@ -245,3 +259,119 @@ def test_mitre_endpoint_includes_tactics(
phishing = next((r for r in data if r["id"] == "T1566"), None)
assert phishing is not None
assert "initial-access" in phishing["tactics"]
# ---------------------------------------------------------------------------
# Sprint 3: get_tactics, lookup_name, get_matrix
# ---------------------------------------------------------------------------
def test_get_tactics_known(bundle_file: pathlib.Path) -> None:
mitre_svc.load_bundle(bundle_file)
tactics = mitre_svc.get_tactics("T1078")
assert "initial-access" in tactics
assert "persistence" in tactics
def test_get_tactics_unknown_returns_empty(bundle_file: pathlib.Path) -> None:
mitre_svc.load_bundle(bundle_file)
assert mitre_svc.get_tactics("T0000") == []
def test_lookup_name_known(bundle_file: pathlib.Path) -> None:
mitre_svc.load_bundle(bundle_file)
assert mitre_svc.lookup_name("T1059") == "Command and Scripting Interpreter"
def test_lookup_name_subtechnique(bundle_file: pathlib.Path) -> None:
mitre_svc.load_bundle(bundle_file)
assert mitre_svc.lookup_name("T1059.001") == "PowerShell"
def test_lookup_name_unknown_returns_none(bundle_file: pathlib.Path) -> None:
mitre_svc.load_bundle(bundle_file)
assert mitre_svc.lookup_name("T0000") is None
def test_get_matrix_returns_ordered_tactics(bundle_file: pathlib.Path) -> None:
mitre_svc.load_bundle(bundle_file)
matrix = mitre_svc.get_matrix()
tactic_ids = [t["tactic_id"] for t in matrix]
# initial-access must come before execution in canonical order.
assert tactic_ids.index("initial-access") < tactic_ids.index("execution")
def test_get_matrix_subtechniques_nested(bundle_file: pathlib.Path) -> None:
mitre_svc.load_bundle(bundle_file)
matrix = mitre_svc.get_matrix()
exec_tactic = next(t for t in matrix if t["tactic_id"] == "execution")
t1059 = next((t for t in exec_tactic["techniques"] if t["id"] == "T1059"), None)
assert t1059 is not None
sub_ids = [s["id"] for s in t1059["subtechniques"]]
assert "T1059.001" in sub_ids
assert "T1059.006" in sub_ids
def test_get_matrix_subtechniques_sorted_by_name(bundle_file: pathlib.Path) -> None:
mitre_svc.load_bundle(bundle_file)
matrix = mitre_svc.get_matrix()
exec_tactic = next(t for t in matrix if t["tactic_id"] == "execution")
t1059 = next(t for t in exec_tactic["techniques"] if t["id"] == "T1059")
names = [s["name"] for s in t1059["subtechniques"]]
assert names == sorted(names)
def test_get_matrix_techniques_sorted_by_name(bundle_file: pathlib.Path) -> None:
mitre_svc.load_bundle(bundle_file)
matrix = mitre_svc.get_matrix()
ia_tactic = next(t for t in matrix if t["tactic_id"] == "initial-access")
names = [t["name"] for t in ia_tactic["techniques"]]
assert names == sorted(names)
def test_get_matrix_technique_no_subtechniques(bundle_file: pathlib.Path) -> None:
mitre_svc.load_bundle(bundle_file)
matrix = mitre_svc.get_matrix()
ia_tactic = next(t for t in matrix if t["tactic_id"] == "initial-access")
phishing = next((t for t in ia_tactic["techniques"] if t["id"] == "T1566"), None)
assert phishing is not None
assert phishing["subtechniques"] == []
def test_matrix_endpoint_ok(
client: FlaskClient, redteam_token: str, bundle_file: pathlib.Path
) -> None:
mitre_svc.load_bundle(bundle_file)
resp = client.get("/api/mitre/matrix", headers=_h(redteam_token))
assert resp.status_code == 200
data = resp.get_json()
assert isinstance(data, list)
tactic_ids = [t["tactic_id"] for t in data]
assert "initial-access" in tactic_ids
assert "execution" in tactic_ids
def test_matrix_endpoint_503_when_not_loaded(
client: FlaskClient, redteam_token: str
) -> None:
mitre_svc.mitre_loaded = False
resp = client.get("/api/mitre/matrix", headers=_h(redteam_token))
assert resp.status_code == 503
def test_matrix_endpoint_requires_auth(client: FlaskClient) -> None:
resp = client.get("/api/mitre/matrix")
assert resp.status_code == 401
def test_matrix_endpoint_all_roles(
client: FlaskClient,
redteam_token: str,
soc_token: str,
admin_token: str,
bundle_file: pathlib.Path,
) -> None:
mitre_svc.load_bundle(bundle_file)
for token in (redteam_token, soc_token, admin_token):
resp = client.get("/api/mitre/matrix", headers=_h(token))
assert resp.status_code == 200

View File

@@ -0,0 +1,347 @@
"""Sprint 3 — multi-technique simulation tests (AC-13)."""
from __future__ import annotations
import json
import pathlib
import pytest
from flask.testing import FlaskClient
from backend.app.services import mitre as mitre_svc
from backend.tests.conftest import auth_headers as _h
# ---------------------------------------------------------------------------
# Minimal STIX fixture (reused from test_mitre.py pattern)
# ---------------------------------------------------------------------------
_FIXTURE_BUNDLE = {
"type": "bundle",
"objects": [
{
"type": "attack-pattern",
"name": "Command and Scripting Interpreter",
"external_references": [{"source_name": "mitre-attack", "external_id": "T1059"}],
"kill_chain_phases": [{"phase_name": "execution", "kill_chain_name": "mitre-attack"}],
},
{
"type": "attack-pattern",
"name": "PowerShell",
"external_references": [{"source_name": "mitre-attack", "external_id": "T1059.001"}],
"kill_chain_phases": [{"phase_name": "execution", "kill_chain_name": "mitre-attack"}],
},
{
"type": "attack-pattern",
"name": "Valid Accounts",
"external_references": [{"source_name": "mitre-attack", "external_id": "T1078"}],
"kill_chain_phases": [
{"phase_name": "initial-access", "kill_chain_name": "mitre-attack"},
{"phase_name": "persistence", "kill_chain_name": "mitre-attack"},
],
},
],
}
@pytest.fixture(autouse=True)
def _reset_mitre():
original_loaded = mitre_svc.mitre_loaded
original_index = list(mitre_svc._index)
original_tactics = dict(mitre_svc._tactics_by_technique)
original_names = dict(mitre_svc._name_by_id)
original_matrix = list(mitre_svc._matrix)
yield
mitre_svc.mitre_loaded = original_loaded
mitre_svc._index = original_index
mitre_svc._tactics_by_technique = original_tactics
mitre_svc._name_by_id = original_names
mitre_svc._matrix = original_matrix
@pytest.fixture()
def bundle_file(tmp_path: pathlib.Path) -> pathlib.Path:
p = tmp_path / "enterprise-attack.json"
p.write_text(json.dumps(_FIXTURE_BUNDLE), encoding="utf-8")
return p
@pytest.fixture()
def loaded_bundle(bundle_file: pathlib.Path) -> pathlib.Path:
mitre_svc.load_bundle(bundle_file)
return bundle_file
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_engagement(client: FlaskClient, token: str) -> dict:
resp = client.post(
"/api/engagements",
headers=_h(token),
json={"name": "Op Sprint3", "start_date": "2026-06-01"},
)
assert resp.status_code == 201
return resp.get_json()
def _make_sim(client: FlaskClient, token: str, eid: int) -> dict:
resp = client.post(
f"/api/engagements/{eid}/simulations",
headers=_h(token),
json={"name": "Technique Test"},
)
assert resp.status_code == 201
return resp.get_json()
def _patch(client: FlaskClient, token: str, sid: int, payload: dict):
return client.patch(f"/api/simulations/{sid}", headers=_h(token), json=payload)
# ---------------------------------------------------------------------------
# AC-13.1 — new simulation has techniques = []
# ---------------------------------------------------------------------------
def test_new_simulation_has_empty_techniques(
client: FlaskClient, redteam_token: str, loaded_bundle
) -> None:
eng = _make_engagement(client, redteam_token)
sim = _make_sim(client, redteam_token, eng["id"])
assert sim["techniques"] == []
# ---------------------------------------------------------------------------
# AC-13.3 — serializer enriches techniques with tactics
# ---------------------------------------------------------------------------
def test_techniques_enriched_with_tactics(
client: FlaskClient, redteam_token: str, loaded_bundle
) -> None:
eng = _make_engagement(client, redteam_token)
sim = _make_sim(client, redteam_token, eng["id"])
_patch(client, redteam_token, sim["id"], {"technique_ids": ["T1078"]})
resp = client.get(f"/api/simulations/{sim['id']}", headers=_h(redteam_token))
assert resp.status_code == 200
techs = resp.get_json()["techniques"]
assert len(techs) == 1
assert techs[0]["id"] == "T1078"
assert "initial-access" in techs[0]["tactics"]
assert "persistence" in techs[0]["tactics"]
def test_techniques_with_unknown_id_returns_empty_tactics(
client: FlaskClient, redteam_token: str, loaded_bundle
) -> None:
"""If a technique was removed from the bundle after save, tactics gracefully = []."""
eng = _make_engagement(client, redteam_token)
sim = _make_sim(client, redteam_token, eng["id"])
# Bypass service, write directly an id not in the bundle.
from backend.app.extensions import db
from backend.app.models.simulation import Simulation
with client.application.app_context():
s = db.session.get(Simulation, sim["id"])
s.techniques = [{"id": "T0000", "name": "Removed Technique"}]
db.session.commit()
resp = client.get(f"/api/simulations/{sim['id']}", headers=_h(redteam_token))
techs = resp.get_json()["techniques"]
assert techs[0]["tactics"] == []
# ---------------------------------------------------------------------------
# AC-13.4 — PATCH technique_ids
# ---------------------------------------------------------------------------
def test_patch_technique_ids_sets_techniques(
client: FlaskClient, redteam_token: str, loaded_bundle
) -> None:
eng = _make_engagement(client, redteam_token)
sim = _make_sim(client, redteam_token, eng["id"])
resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059", "T1078"]})
assert resp.status_code == 200
techs = resp.get_json()["techniques"]
assert len(techs) == 2
ids = [t["id"] for t in techs]
assert "T1059" in ids
assert "T1078" in ids
def test_patch_technique_ids_resolves_name(
client: FlaskClient, redteam_token: str, loaded_bundle
) -> None:
eng = _make_engagement(client, redteam_token)
sim = _make_sim(client, redteam_token, eng["id"])
resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059"]})
assert resp.status_code == 200
tech = resp.get_json()["techniques"][0]
assert tech["name"] == "Command and Scripting Interpreter"
def test_patch_technique_ids_unknown_returns_400(
client: FlaskClient, redteam_token: str, loaded_bundle
) -> None:
eng = _make_engagement(client, redteam_token)
sim = _make_sim(client, redteam_token, eng["id"])
resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T9999"]})
assert resp.status_code == 400
assert "unknown technique id: T9999" in resp.get_json()["error"]
def test_patch_technique_ids_partial_unknown_rejected(
client: FlaskClient, redteam_token: str, loaded_bundle
) -> None:
eng = _make_engagement(client, redteam_token)
sim = _make_sim(client, redteam_token, eng["id"])
# One valid, one unknown — whole request rejected.
resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059", "T9999"]})
assert resp.status_code == 400
def test_patch_technique_ids_includes_subtechnique(
client: FlaskClient, redteam_token: str, loaded_bundle
) -> None:
eng = _make_engagement(client, redteam_token)
sim = _make_sim(client, redteam_token, eng["id"])
resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059.001"]})
assert resp.status_code == 200
techs = resp.get_json()["techniques"]
assert techs[0]["id"] == "T1059.001"
assert techs[0]["name"] == "PowerShell"
def test_patch_technique_ids_replaces_list(
client: FlaskClient, redteam_token: str, loaded_bundle
) -> None:
eng = _make_engagement(client, redteam_token)
sim = _make_sim(client, redteam_token, eng["id"])
_patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059"]})
resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1078"]})
assert resp.status_code == 200
ids = [t["id"] for t in resp.get_json()["techniques"]]
assert ids == ["T1078"]
def test_patch_technique_ids_empty_clears_list(
client: FlaskClient, redteam_token: str, loaded_bundle
) -> None:
eng = _make_engagement(client, redteam_token)
sim = _make_sim(client, redteam_token, eng["id"])
_patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059"]})
resp = _patch(client, redteam_token, sim["id"], {"technique_ids": []})
assert resp.status_code == 200
assert resp.get_json()["techniques"] == []
def test_patch_technique_ids_not_list_returns_400(
client: FlaskClient, redteam_token: str, loaded_bundle
) -> None:
eng = _make_engagement(client, redteam_token)
sim = _make_sim(client, redteam_token, eng["id"])
resp = _patch(client, redteam_token, sim["id"], {"technique_ids": "T1059"})
assert resp.status_code == 400
# ---------------------------------------------------------------------------
# Dedup (spec-reviewer note: AC-13.4)
# ---------------------------------------------------------------------------
def test_patch_technique_ids_deduplicates(
client: FlaskClient, redteam_token: str, loaded_bundle
) -> None:
eng = _make_engagement(client, redteam_token)
sim = _make_sim(client, redteam_token, eng["id"])
resp = _patch(
client, redteam_token, sim["id"], {"technique_ids": ["T1059", "T1078", "T1059"]}
)
assert resp.status_code == 200
techs = resp.get_json()["techniques"]
assert len(techs) == 2
# Order preserved: T1059 first.
assert techs[0]["id"] == "T1059"
assert techs[1]["id"] == "T1078"
# ---------------------------------------------------------------------------
# AC-13.5 — auto-transition on technique_ids
# ---------------------------------------------------------------------------
def test_technique_ids_non_empty_triggers_auto_transition(
client: FlaskClient, redteam_token: str, loaded_bundle
) -> None:
eng = _make_engagement(client, redteam_token)
sim = _make_sim(client, redteam_token, eng["id"])
assert sim["status"] == "pending"
resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059"]})
assert resp.status_code == 200
assert resp.get_json()["status"] == "in_progress"
def test_technique_ids_empty_does_not_trigger_auto_transition(
client: FlaskClient, redteam_token: str, loaded_bundle
) -> None:
eng = _make_engagement(client, redteam_token)
sim = _make_sim(client, redteam_token, eng["id"])
resp = _patch(client, redteam_token, sim["id"], {"technique_ids": []})
assert resp.status_code == 200
assert resp.get_json()["status"] == "pending"
# ---------------------------------------------------------------------------
# SOC cannot patch technique_ids (it's a redteam field)
# ---------------------------------------------------------------------------
def test_soc_cannot_patch_technique_ids(
client: FlaskClient, redteam_token: str, soc_token: str, loaded_bundle
) -> None:
eng = _make_engagement(client, redteam_token)
sim = _make_sim(client, redteam_token, eng["id"])
# Advance to review_required so SOC can touch the simulation at all.
client.post(
f"/api/simulations/{sim['id']}/transition",
headers=_h(redteam_token),
json={"to": "review_required"},
)
resp = _patch(client, soc_token, sim["id"], {"technique_ids": ["T1059"]})
assert resp.status_code == 403
# ---------------------------------------------------------------------------
# Migration backfill test (inline, no Alembic runner needed)
# ---------------------------------------------------------------------------
def test_migration_backfill_logic() -> None:
"""Verify the backfill logic used in upgrade(): scalar → [{id, name}]."""
import json as _json
def _backfill(tech_id, tech_name):
if tech_id:
return _json.loads(_json.dumps([{"id": tech_id, "name": tech_name or ""}]))
return []
assert _backfill("T1059", "Command and Scripting Interpreter") == [
{"id": "T1059", "name": "Command and Scripting Interpreter"}
]
assert _backfill(None, None) == []
assert _backfill("T1059", None) == [{"id": "T1059", "name": ""}]