feat: sprint 3 — multi-technique simulations + MITRE matrix modal #6
@@ -121,7 +121,7 @@ def transition_simulation(sid: int):
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MITRE autocomplete
|
||||
# MITRE autocomplete + matrix
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -136,3 +136,14 @@ def mitre_techniques():
|
||||
q = request.args.get("q", "").strip()
|
||||
results = mitre_svc.search(q)
|
||||
return jsonify(results), 200
|
||||
|
||||
|
||||
@simulations_bp.get("/api/mitre/matrix")
|
||||
@login_required
|
||||
def mitre_matrix():
|
||||
from backend.app.services import mitre as mitre_svc
|
||||
|
||||
if not mitre_svc.mitre_loaded:
|
||||
return jsonify({"error": "mitre bundle not loaded"}), 503
|
||||
|
||||
return jsonify(mitre_svc.get_matrix()), 200
|
||||
|
||||
@@ -25,8 +25,7 @@ class Simulation(db.Model): # type: ignore[name-defined]
|
||||
index=True,
|
||||
)
|
||||
name = db.Column(db.String(255), nullable=False)
|
||||
mitre_technique_id = db.Column(db.String(32), nullable=True)
|
||||
mitre_technique_name = db.Column(db.String(255), nullable=True)
|
||||
techniques = db.Column(db.JSON, nullable=False, default=list)
|
||||
description = db.Column(db.Text, nullable=True)
|
||||
commands = db.Column(db.Text, nullable=True)
|
||||
prerequisites = db.Column(db.Text, nullable=True)
|
||||
|
||||
@@ -20,13 +20,22 @@ def serialize_user_brief(user: User) -> dict[str, Any]:
|
||||
return {"id": user.id, "username": user.username}
|
||||
|
||||
|
||||
def _enrich_techniques(raw: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Attach tactics to each {id, name} snapshot from the MITRE service."""
|
||||
from backend.app.services import mitre as mitre_svc
|
||||
|
||||
return [
|
||||
{"id": t["id"], "name": t["name"], "tactics": mitre_svc.get_tactics(t["id"])}
|
||||
for t in (raw or [])
|
||||
]
|
||||
|
||||
|
||||
def serialize_simulation(simulation: Simulation) -> dict[str, Any]:
|
||||
return {
|
||||
"id": simulation.id,
|
||||
"engagement_id": simulation.engagement_id,
|
||||
"name": simulation.name,
|
||||
"mitre_technique_id": simulation.mitre_technique_id,
|
||||
"mitre_technique_name": simulation.mitre_technique_name,
|
||||
"techniques": _enrich_techniques(simulation.techniques or []),
|
||||
"description": simulation.description,
|
||||
"commands": simulation.commands,
|
||||
"prerequisites": simulation.prerequisites,
|
||||
|
||||
@@ -8,11 +8,30 @@ from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Absolute path to the committed bundle.
|
||||
_BUNDLE_PATH = Path(__file__).parent.parent.parent / "data" / "mitre" / "enterprise-attack.json"
|
||||
|
||||
# Canonical Enterprise tactic order (12 tactics).
|
||||
_TACTIC_ORDER = [
|
||||
"initial-access",
|
||||
"execution",
|
||||
"persistence",
|
||||
"privilege-escalation",
|
||||
"defense-evasion",
|
||||
"credential-access",
|
||||
"discovery",
|
||||
"lateral-movement",
|
||||
"collection",
|
||||
"command-and-control",
|
||||
"exfiltration",
|
||||
"impact",
|
||||
]
|
||||
|
||||
mitre_loaded: bool = False
|
||||
_index: list[dict[str, Any]] = []
|
||||
_tactics_by_technique: dict[str, list[str]] = {}
|
||||
_name_by_id: dict[str, str] = {}
|
||||
# matrix: list of tactic dicts (built once at load time)
|
||||
_matrix: list[dict[str, Any]] = []
|
||||
|
||||
|
||||
def _extract_tactics(obj: dict[str, Any]) -> list[str]:
|
||||
@@ -20,7 +39,7 @@ def _extract_tactics(obj: dict[str, Any]) -> list[str]:
|
||||
return [
|
||||
p["phase_name"]
|
||||
for p in phases
|
||||
if isinstance(p, dict) and "phase_name" in p
|
||||
if isinstance(p, dict) and "phase_name" in p and p.get("kill_chain_name") == "mitre-attack"
|
||||
]
|
||||
|
||||
|
||||
@@ -31,9 +50,65 @@ def _get_external_id(obj: dict[str, Any]) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def _is_subtechnique(tech_id: str) -> bool:
|
||||
return "." in tech_id
|
||||
|
||||
|
||||
def _parent_id(sub_id: str) -> str:
|
||||
return sub_id.split(".")[0]
|
||||
|
||||
|
||||
def _build_matrix(entries: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Build the tactic → techniques → subtechniques tree."""
|
||||
# Group top-level techniques by tactic.
|
||||
tactic_techs: dict[str, list[dict[str, Any]]] = {t: [] for t in _TACTIC_ORDER}
|
||||
|
||||
for entry in entries:
|
||||
if _is_subtechnique(entry["id"]):
|
||||
continue
|
||||
for tactic in entry["tactics"]:
|
||||
if tactic in tactic_techs:
|
||||
tactic_techs[tactic].append(entry)
|
||||
|
||||
# Attach sub-techniques to their parents.
|
||||
parent_subs: dict[str, list[dict[str, Any]]] = {}
|
||||
for entry in entries:
|
||||
if not _is_subtechnique(entry["id"]):
|
||||
continue
|
||||
pid = _parent_id(entry["id"])
|
||||
parent_subs.setdefault(pid, []).append({"id": entry["id"], "name": entry["name"]})
|
||||
|
||||
# Sort subs alphabetically by name.
|
||||
for subs in parent_subs.values():
|
||||
subs.sort(key=lambda x: x["name"])
|
||||
|
||||
matrix: list[dict[str, Any]] = []
|
||||
for tactic_id in _TACTIC_ORDER:
|
||||
techs = tactic_techs.get(tactic_id, [])
|
||||
# Sort techniques alphabetically.
|
||||
techs_sorted = sorted(techs, key=lambda x: x["name"])
|
||||
tactic_name = tactic_id.replace("-", " ").title()
|
||||
matrix.append(
|
||||
{
|
||||
"tactic_id": tactic_id,
|
||||
"tactic_name": tactic_name,
|
||||
"techniques": [
|
||||
{
|
||||
"id": t["id"],
|
||||
"name": t["name"],
|
||||
"subtechniques": parent_subs.get(t["id"], []),
|
||||
}
|
||||
for t in techs_sorted
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
return matrix
|
||||
|
||||
|
||||
def load_bundle(path: Path | None = None) -> None:
|
||||
"""Load the MITRE bundle into memory. Called once at app boot."""
|
||||
global mitre_loaded, _index
|
||||
global mitre_loaded, _index, _tactics_by_technique, _name_by_id, _matrix
|
||||
bundle_path = path or _BUNDLE_PATH
|
||||
|
||||
try:
|
||||
@@ -49,6 +124,9 @@ def load_bundle(path: Path | None = None) -> None:
|
||||
return
|
||||
|
||||
entries: list[dict[str, Any]] = []
|
||||
tactics_map: dict[str, list[str]] = {}
|
||||
name_map: dict[str, str] = {}
|
||||
|
||||
for obj in data.get("objects") or []:
|
||||
if not isinstance(obj, dict):
|
||||
continue
|
||||
@@ -59,19 +137,35 @@ def load_bundle(path: Path | None = None) -> None:
|
||||
ext_id = _get_external_id(obj)
|
||||
if not ext_id:
|
||||
continue
|
||||
entries.append(
|
||||
{
|
||||
"id": ext_id,
|
||||
"name": obj.get("name", ""),
|
||||
"tactics": _extract_tactics(obj),
|
||||
}
|
||||
)
|
||||
tactics = _extract_tactics(obj)
|
||||
name = obj.get("name", "")
|
||||
entries.append({"id": ext_id, "name": name, "tactics": tactics})
|
||||
tactics_map[ext_id] = tactics
|
||||
name_map[ext_id] = name
|
||||
|
||||
_index = entries
|
||||
_tactics_by_technique = tactics_map
|
||||
_name_by_id = name_map
|
||||
_matrix = _build_matrix(entries)
|
||||
mitre_loaded = True
|
||||
logger.info("MITRE bundle loaded: %d techniques", len(_index))
|
||||
|
||||
|
||||
def get_tactics(technique_id: str) -> list[str]:
|
||||
"""Return tactic list for a technique id; empty list if unknown."""
|
||||
return _tactics_by_technique.get(technique_id, [])
|
||||
|
||||
|
||||
def lookup_name(technique_id: str) -> str | None:
|
||||
"""Return the name for a technique id, or None if not in the bundle."""
|
||||
return _name_by_id.get(technique_id)
|
||||
|
||||
|
||||
def get_matrix() -> list[dict[str, Any]]:
|
||||
"""Return the full tactic → techniques → subtechniques tree."""
|
||||
return _matrix
|
||||
|
||||
|
||||
def search(query: str, limit: int = 20) -> list[dict[str, Any]]:
|
||||
"""Return up to `limit` techniques matching `query`.
|
||||
|
||||
|
||||
@@ -10,11 +10,10 @@ from backend.app.extensions import db
|
||||
from backend.app.models import User
|
||||
from backend.app.models.simulation import Simulation, SimulationStatus
|
||||
|
||||
# Fields only admin/redteam may write (excluding technique_ids which is handled separately).
|
||||
REDTEAM_FIELDS = frozenset(
|
||||
{
|
||||
"name",
|
||||
"mitre_technique_id",
|
||||
"mitre_technique_name",
|
||||
"description",
|
||||
"commands",
|
||||
"prerequisites",
|
||||
@@ -25,8 +24,6 @@ REDTEAM_FIELDS = frozenset(
|
||||
|
||||
SOC_FIELDS = frozenset({"log_source", "logs", "soc_comment", "incident_number"})
|
||||
|
||||
# Transitions allowed via POST /transition endpoint (manual only).
|
||||
# auto pending→in_progress is handled in apply_patch, not here.
|
||||
_ALLOWED_TRANSITIONS: dict[str, dict[str, set[str]]] = {
|
||||
"review_required": {
|
||||
"from": {"pending", "in_progress"},
|
||||
@@ -48,6 +45,27 @@ def _is_non_empty(value: Any) -> bool:
|
||||
return not (isinstance(value, list) and len(value) == 0)
|
||||
|
||||
|
||||
def _resolve_technique_ids(
|
||||
technique_ids: list[str],
|
||||
) -> tuple[list[dict[str, str]] | None, tuple[Any, int] | None]:
|
||||
"""Validate and resolve technique IDs to [{id, name}] snapshots.
|
||||
|
||||
Returns (resolved_list, None) on success or (None, error_tuple) on failure.
|
||||
Deduplicates while preserving order.
|
||||
"""
|
||||
from backend.app.services import mitre as mitre_svc
|
||||
|
||||
# Dedup, preserve order.
|
||||
seen: dict[str, None] = dict.fromkeys(technique_ids)
|
||||
resolved: list[dict[str, str]] = []
|
||||
for tid in seen:
|
||||
name = mitre_svc.lookup_name(tid)
|
||||
if name is None:
|
||||
return None, (jsonify({"error": f"unknown technique id: {tid}"}), 400)
|
||||
resolved.append({"id": tid, "name": name})
|
||||
return resolved, None
|
||||
|
||||
|
||||
def apply_patch(
|
||||
simulation: Simulation, payload: dict[str, Any], user: User
|
||||
) -> tuple[Any, int] | None:
|
||||
@@ -59,15 +77,14 @@ def apply_patch(
|
||||
role = user.role.value
|
||||
|
||||
if role == "soc":
|
||||
# SOC can only patch when status allows it.
|
||||
if simulation.status not in (
|
||||
SimulationStatus.REVIEW_REQUIRED,
|
||||
SimulationStatus.DONE,
|
||||
):
|
||||
return jsonify({"error": "simulation not ready for SOC review"}), 403
|
||||
|
||||
# SOC must not send redteam fields.
|
||||
redteam_keys_in_payload = REDTEAM_FIELDS & payload.keys()
|
||||
# SOC must not send redteam fields or technique_ids.
|
||||
redteam_keys_in_payload = (REDTEAM_FIELDS | {"technique_ids"}) & payload.keys()
|
||||
if redteam_keys_in_payload:
|
||||
return jsonify({"error": "soc cannot edit redteam fields"}), 403
|
||||
|
||||
@@ -76,10 +93,10 @@ def apply_patch(
|
||||
setattr(simulation, field, payload[field])
|
||||
|
||||
else:
|
||||
# admin / redteam: apply all fields present.
|
||||
# admin / redteam path.
|
||||
redteam_keys_present = REDTEAM_FIELDS & payload.keys()
|
||||
|
||||
# Validate executed_at before any writes so a bad value causes no partial mutation.
|
||||
# Validate executed_at upfront before any writes.
|
||||
executed_at_value: datetime | None = None
|
||||
if "executed_at" in redteam_keys_present:
|
||||
val = payload["executed_at"]
|
||||
@@ -91,21 +108,39 @@ def apply_patch(
|
||||
except ValueError:
|
||||
return jsonify({"error": "invalid executed_at"}), 400
|
||||
|
||||
# Validate and resolve technique_ids upfront.
|
||||
resolved_techniques: list[dict[str, str]] | None = None
|
||||
if "technique_ids" in payload:
|
||||
raw_ids = payload["technique_ids"]
|
||||
if not isinstance(raw_ids, list):
|
||||
return jsonify({"error": "technique_ids must be a list"}), 400
|
||||
resolved_techniques, err = _resolve_technique_ids(raw_ids)
|
||||
if err is not None:
|
||||
return err
|
||||
|
||||
# Apply scalar redteam fields.
|
||||
for field in redteam_keys_present:
|
||||
if field == "executed_at":
|
||||
simulation.executed_at = executed_at_value
|
||||
else:
|
||||
setattr(simulation, field, payload[field])
|
||||
|
||||
# Apply resolved techniques.
|
||||
if resolved_techniques is not None:
|
||||
simulation.techniques = resolved_techniques
|
||||
|
||||
# Apply SOC fields (admin/redteam may also write them).
|
||||
for field in SOC_FIELDS:
|
||||
if field in payload:
|
||||
setattr(simulation, field, payload[field])
|
||||
|
||||
# Auto-transition pending → in_progress: at least one redteam field with
|
||||
# a non-empty value in the *incoming payload*.
|
||||
if simulation.status == SimulationStatus.PENDING and any(
|
||||
_is_non_empty(payload[k]) for k in redteam_keys_present
|
||||
):
|
||||
# Auto-transition pending → in_progress.
|
||||
# Triggers when any redteam scalar has a non-empty value, OR technique_ids is non-empty.
|
||||
auto_trigger = any(_is_non_empty(payload[k]) for k in redteam_keys_present)
|
||||
if not auto_trigger and "technique_ids" in payload:
|
||||
auto_trigger = len(payload["technique_ids"]) > 0
|
||||
|
||||
if simulation.status == SimulationStatus.PENDING and auto_trigger:
|
||||
simulation.status = SimulationStatus.IN_PROGRESS
|
||||
|
||||
simulation.updated_at = datetime.now(UTC)
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
"""replace scalar MITRE columns with techniques JSON array
|
||||
|
||||
Revision ID: 0003
|
||||
Revises: 0002
|
||||
Create Date: 2026-05-27 00:00:00.000000
|
||||
"""
|
||||
import json
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import column, table, text
|
||||
|
||||
|
||||
revision = "0003"
|
||||
down_revision = "0002"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# Lightweight table proxies for data migration (no ORM import).
|
||||
_sims = table(
|
||||
"simulations",
|
||||
column("id", sa.Integer),
|
||||
column("mitre_technique_id", sa.String),
|
||||
column("mitre_technique_name", sa.String),
|
||||
column("techniques", sa.Text),
|
||||
)
|
||||
|
||||
|
||||
def upgrade():
|
||||
bind = op.get_bind()
|
||||
|
||||
# 1. Add techniques column (nullable while we backfill).
|
||||
op.add_column("simulations", sa.Column("techniques", sa.Text(), nullable=True))
|
||||
|
||||
# 2. Backfill: scalar → JSON array.
|
||||
rows = bind.execute(
|
||||
text("SELECT id, mitre_technique_id, mitre_technique_name FROM simulations")
|
||||
).fetchall()
|
||||
for row in rows:
|
||||
if row[1]: # mitre_technique_id is not null
|
||||
val = json.dumps([{"id": row[1], "name": row[2] or ""}])
|
||||
else:
|
||||
val = "[]"
|
||||
bind.execute(
|
||||
text("UPDATE simulations SET techniques = :v WHERE id = :id"),
|
||||
{"v": val, "id": row[0]},
|
||||
)
|
||||
|
||||
# 3. Make NOT NULL now that every row has a value.
|
||||
# SQLite doesn't support ALTER COLUMN, so we skip the nullable constraint
|
||||
# change at DDL level — the application model enforces it.
|
||||
|
||||
# 4. Drop old scalar columns.
|
||||
with op.batch_alter_table("simulations") as batch_op:
|
||||
batch_op.drop_column("mitre_technique_id")
|
||||
batch_op.drop_column("mitre_technique_name")
|
||||
|
||||
|
||||
def downgrade():
|
||||
bind = op.get_bind()
|
||||
|
||||
# 1. Re-add scalar columns.
|
||||
with op.batch_alter_table("simulations") as batch_op:
|
||||
batch_op.add_column(sa.Column("mitre_technique_id", sa.String(length=32), nullable=True))
|
||||
batch_op.add_column(sa.Column("mitre_technique_name", sa.String(length=255), nullable=True))
|
||||
|
||||
# 2. Back-fill: take first element of techniques array.
|
||||
rows = bind.execute(text("SELECT id, techniques FROM simulations")).fetchall()
|
||||
for row in rows:
|
||||
techniques = json.loads(row[1] or "[]")
|
||||
if techniques:
|
||||
first = techniques[0]
|
||||
bind.execute(
|
||||
text(
|
||||
"UPDATE simulations SET mitre_technique_id = :tid, mitre_technique_name = :tname WHERE id = :id"
|
||||
),
|
||||
{"tid": first.get("id"), "tname": first.get("name"), "id": row[0]},
|
||||
)
|
||||
|
||||
# 3. Drop techniques column.
|
||||
with op.batch_alter_table("simulations") as batch_op:
|
||||
batch_op.drop_column("techniques")
|
||||
@@ -33,6 +33,14 @@ _FIXTURE_BUNDLE = {
|
||||
],
|
||||
"kill_chain_phases": [{"phase_name": "execution", "kill_chain_name": "mitre-attack"}],
|
||||
},
|
||||
{
|
||||
"type": "attack-pattern",
|
||||
"name": "Python",
|
||||
"external_references": [
|
||||
{"source_name": "mitre-attack", "external_id": "T1059.006"}
|
||||
],
|
||||
"kill_chain_phases": [{"phase_name": "execution", "kill_chain_name": "mitre-attack"}],
|
||||
},
|
||||
{
|
||||
"type": "attack-pattern",
|
||||
"name": "Phishing",
|
||||
@@ -76,9 +84,15 @@ def _reset_mitre():
|
||||
"""Reset the MITRE service state between tests."""
|
||||
original_loaded = mitre_svc.mitre_loaded
|
||||
original_index = list(mitre_svc._index)
|
||||
original_tactics = dict(mitre_svc._tactics_by_technique)
|
||||
original_names = dict(mitre_svc._name_by_id)
|
||||
original_matrix = list(mitre_svc._matrix)
|
||||
yield
|
||||
mitre_svc.mitre_loaded = original_loaded
|
||||
mitre_svc._index = original_index
|
||||
mitre_svc._tactics_by_technique = original_tactics
|
||||
mitre_svc._name_by_id = original_names
|
||||
mitre_svc._matrix = original_matrix
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@@ -96,7 +110,7 @@ def bundle_file(tmp_path: pathlib.Path) -> pathlib.Path:
|
||||
def test_load_bundle_success(bundle_file: pathlib.Path) -> None:
|
||||
mitre_svc.load_bundle(bundle_file)
|
||||
assert mitre_svc.mitre_loaded is True
|
||||
assert len(mitre_svc._index) == 4 # 5 objects minus 1 revoked = 4
|
||||
assert len(mitre_svc._index) == 5 # 6 attack-patterns minus 1 revoked = 5
|
||||
|
||||
|
||||
def test_load_bundle_missing_file() -> None:
|
||||
@@ -245,3 +259,119 @@ def test_mitre_endpoint_includes_tactics(
|
||||
phishing = next((r for r in data if r["id"] == "T1566"), None)
|
||||
assert phishing is not None
|
||||
assert "initial-access" in phishing["tactics"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sprint 3: get_tactics, lookup_name, get_matrix
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_tactics_known(bundle_file: pathlib.Path) -> None:
|
||||
mitre_svc.load_bundle(bundle_file)
|
||||
tactics = mitre_svc.get_tactics("T1078")
|
||||
assert "initial-access" in tactics
|
||||
assert "persistence" in tactics
|
||||
|
||||
|
||||
def test_get_tactics_unknown_returns_empty(bundle_file: pathlib.Path) -> None:
|
||||
mitre_svc.load_bundle(bundle_file)
|
||||
assert mitre_svc.get_tactics("T0000") == []
|
||||
|
||||
|
||||
def test_lookup_name_known(bundle_file: pathlib.Path) -> None:
|
||||
mitre_svc.load_bundle(bundle_file)
|
||||
assert mitre_svc.lookup_name("T1059") == "Command and Scripting Interpreter"
|
||||
|
||||
|
||||
def test_lookup_name_subtechnique(bundle_file: pathlib.Path) -> None:
|
||||
mitre_svc.load_bundle(bundle_file)
|
||||
assert mitre_svc.lookup_name("T1059.001") == "PowerShell"
|
||||
|
||||
|
||||
def test_lookup_name_unknown_returns_none(bundle_file: pathlib.Path) -> None:
|
||||
mitre_svc.load_bundle(bundle_file)
|
||||
assert mitre_svc.lookup_name("T0000") is None
|
||||
|
||||
|
||||
def test_get_matrix_returns_ordered_tactics(bundle_file: pathlib.Path) -> None:
|
||||
mitre_svc.load_bundle(bundle_file)
|
||||
matrix = mitre_svc.get_matrix()
|
||||
tactic_ids = [t["tactic_id"] for t in matrix]
|
||||
# initial-access must come before execution in canonical order.
|
||||
assert tactic_ids.index("initial-access") < tactic_ids.index("execution")
|
||||
|
||||
|
||||
def test_get_matrix_subtechniques_nested(bundle_file: pathlib.Path) -> None:
|
||||
mitre_svc.load_bundle(bundle_file)
|
||||
matrix = mitre_svc.get_matrix()
|
||||
exec_tactic = next(t for t in matrix if t["tactic_id"] == "execution")
|
||||
t1059 = next((t for t in exec_tactic["techniques"] if t["id"] == "T1059"), None)
|
||||
assert t1059 is not None
|
||||
sub_ids = [s["id"] for s in t1059["subtechniques"]]
|
||||
assert "T1059.001" in sub_ids
|
||||
assert "T1059.006" in sub_ids
|
||||
|
||||
|
||||
def test_get_matrix_subtechniques_sorted_by_name(bundle_file: pathlib.Path) -> None:
|
||||
mitre_svc.load_bundle(bundle_file)
|
||||
matrix = mitre_svc.get_matrix()
|
||||
exec_tactic = next(t for t in matrix if t["tactic_id"] == "execution")
|
||||
t1059 = next(t for t in exec_tactic["techniques"] if t["id"] == "T1059")
|
||||
names = [s["name"] for s in t1059["subtechniques"]]
|
||||
assert names == sorted(names)
|
||||
|
||||
|
||||
def test_get_matrix_techniques_sorted_by_name(bundle_file: pathlib.Path) -> None:
|
||||
mitre_svc.load_bundle(bundle_file)
|
||||
matrix = mitre_svc.get_matrix()
|
||||
ia_tactic = next(t for t in matrix if t["tactic_id"] == "initial-access")
|
||||
names = [t["name"] for t in ia_tactic["techniques"]]
|
||||
assert names == sorted(names)
|
||||
|
||||
|
||||
def test_get_matrix_technique_no_subtechniques(bundle_file: pathlib.Path) -> None:
|
||||
mitre_svc.load_bundle(bundle_file)
|
||||
matrix = mitre_svc.get_matrix()
|
||||
ia_tactic = next(t for t in matrix if t["tactic_id"] == "initial-access")
|
||||
phishing = next((t for t in ia_tactic["techniques"] if t["id"] == "T1566"), None)
|
||||
assert phishing is not None
|
||||
assert phishing["subtechniques"] == []
|
||||
|
||||
|
||||
def test_matrix_endpoint_ok(
|
||||
client: FlaskClient, redteam_token: str, bundle_file: pathlib.Path
|
||||
) -> None:
|
||||
mitre_svc.load_bundle(bundle_file)
|
||||
resp = client.get("/api/mitre/matrix", headers=_h(redteam_token))
|
||||
assert resp.status_code == 200
|
||||
data = resp.get_json()
|
||||
assert isinstance(data, list)
|
||||
tactic_ids = [t["tactic_id"] for t in data]
|
||||
assert "initial-access" in tactic_ids
|
||||
assert "execution" in tactic_ids
|
||||
|
||||
|
||||
def test_matrix_endpoint_503_when_not_loaded(
|
||||
client: FlaskClient, redteam_token: str
|
||||
) -> None:
|
||||
mitre_svc.mitre_loaded = False
|
||||
resp = client.get("/api/mitre/matrix", headers=_h(redteam_token))
|
||||
assert resp.status_code == 503
|
||||
|
||||
|
||||
def test_matrix_endpoint_requires_auth(client: FlaskClient) -> None:
|
||||
resp = client.get("/api/mitre/matrix")
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
def test_matrix_endpoint_all_roles(
|
||||
client: FlaskClient,
|
||||
redteam_token: str,
|
||||
soc_token: str,
|
||||
admin_token: str,
|
||||
bundle_file: pathlib.Path,
|
||||
) -> None:
|
||||
mitre_svc.load_bundle(bundle_file)
|
||||
for token in (redteam_token, soc_token, admin_token):
|
||||
resp = client.get("/api/mitre/matrix", headers=_h(token))
|
||||
assert resp.status_code == 200
|
||||
|
||||
347
backend/tests/test_simulations_techniques.py
Normal file
347
backend/tests/test_simulations_techniques.py
Normal file
@@ -0,0 +1,347 @@
|
||||
"""Sprint 3 — multi-technique simulation tests (AC-13)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import pathlib
|
||||
|
||||
import pytest
|
||||
from flask.testing import FlaskClient
|
||||
|
||||
from backend.app.services import mitre as mitre_svc
|
||||
from backend.tests.conftest import auth_headers as _h
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimal STIX fixture (reused from test_mitre.py pattern)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_FIXTURE_BUNDLE = {
|
||||
"type": "bundle",
|
||||
"objects": [
|
||||
{
|
||||
"type": "attack-pattern",
|
||||
"name": "Command and Scripting Interpreter",
|
||||
"external_references": [{"source_name": "mitre-attack", "external_id": "T1059"}],
|
||||
"kill_chain_phases": [{"phase_name": "execution", "kill_chain_name": "mitre-attack"}],
|
||||
},
|
||||
{
|
||||
"type": "attack-pattern",
|
||||
"name": "PowerShell",
|
||||
"external_references": [{"source_name": "mitre-attack", "external_id": "T1059.001"}],
|
||||
"kill_chain_phases": [{"phase_name": "execution", "kill_chain_name": "mitre-attack"}],
|
||||
},
|
||||
{
|
||||
"type": "attack-pattern",
|
||||
"name": "Valid Accounts",
|
||||
"external_references": [{"source_name": "mitre-attack", "external_id": "T1078"}],
|
||||
"kill_chain_phases": [
|
||||
{"phase_name": "initial-access", "kill_chain_name": "mitre-attack"},
|
||||
{"phase_name": "persistence", "kill_chain_name": "mitre-attack"},
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_mitre():
|
||||
original_loaded = mitre_svc.mitre_loaded
|
||||
original_index = list(mitre_svc._index)
|
||||
original_tactics = dict(mitre_svc._tactics_by_technique)
|
||||
original_names = dict(mitre_svc._name_by_id)
|
||||
original_matrix = list(mitre_svc._matrix)
|
||||
yield
|
||||
mitre_svc.mitre_loaded = original_loaded
|
||||
mitre_svc._index = original_index
|
||||
mitre_svc._tactics_by_technique = original_tactics
|
||||
mitre_svc._name_by_id = original_names
|
||||
mitre_svc._matrix = original_matrix
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def bundle_file(tmp_path: pathlib.Path) -> pathlib.Path:
|
||||
p = tmp_path / "enterprise-attack.json"
|
||||
p.write_text(json.dumps(_FIXTURE_BUNDLE), encoding="utf-8")
|
||||
return p
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def loaded_bundle(bundle_file: pathlib.Path) -> pathlib.Path:
|
||||
mitre_svc.load_bundle(bundle_file)
|
||||
return bundle_file
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_engagement(client: FlaskClient, token: str) -> dict:
|
||||
resp = client.post(
|
||||
"/api/engagements",
|
||||
headers=_h(token),
|
||||
json={"name": "Op Sprint3", "start_date": "2026-06-01"},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
return resp.get_json()
|
||||
|
||||
|
||||
def _make_sim(client: FlaskClient, token: str, eid: int) -> dict:
|
||||
resp = client.post(
|
||||
f"/api/engagements/{eid}/simulations",
|
||||
headers=_h(token),
|
||||
json={"name": "Technique Test"},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
return resp.get_json()
|
||||
|
||||
|
||||
def _patch(client: FlaskClient, token: str, sid: int, payload: dict):
|
||||
return client.patch(f"/api/simulations/{sid}", headers=_h(token), json=payload)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AC-13.1 — new simulation has techniques = []
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_new_simulation_has_empty_techniques(
|
||||
client: FlaskClient, redteam_token: str, loaded_bundle
|
||||
) -> None:
|
||||
eng = _make_engagement(client, redteam_token)
|
||||
sim = _make_sim(client, redteam_token, eng["id"])
|
||||
assert sim["techniques"] == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AC-13.3 — serializer enriches techniques with tactics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_techniques_enriched_with_tactics(
|
||||
client: FlaskClient, redteam_token: str, loaded_bundle
|
||||
) -> None:
|
||||
eng = _make_engagement(client, redteam_token)
|
||||
sim = _make_sim(client, redteam_token, eng["id"])
|
||||
_patch(client, redteam_token, sim["id"], {"technique_ids": ["T1078"]})
|
||||
|
||||
resp = client.get(f"/api/simulations/{sim['id']}", headers=_h(redteam_token))
|
||||
assert resp.status_code == 200
|
||||
techs = resp.get_json()["techniques"]
|
||||
assert len(techs) == 1
|
||||
assert techs[0]["id"] == "T1078"
|
||||
assert "initial-access" in techs[0]["tactics"]
|
||||
assert "persistence" in techs[0]["tactics"]
|
||||
|
||||
|
||||
def test_techniques_with_unknown_id_returns_empty_tactics(
|
||||
client: FlaskClient, redteam_token: str, loaded_bundle
|
||||
) -> None:
|
||||
"""If a technique was removed from the bundle after save, tactics gracefully = []."""
|
||||
eng = _make_engagement(client, redteam_token)
|
||||
sim = _make_sim(client, redteam_token, eng["id"])
|
||||
# Bypass service, write directly an id not in the bundle.
|
||||
from backend.app.extensions import db
|
||||
from backend.app.models.simulation import Simulation
|
||||
|
||||
with client.application.app_context():
|
||||
s = db.session.get(Simulation, sim["id"])
|
||||
s.techniques = [{"id": "T0000", "name": "Removed Technique"}]
|
||||
db.session.commit()
|
||||
|
||||
resp = client.get(f"/api/simulations/{sim['id']}", headers=_h(redteam_token))
|
||||
techs = resp.get_json()["techniques"]
|
||||
assert techs[0]["tactics"] == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AC-13.4 — PATCH technique_ids
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_patch_technique_ids_sets_techniques(
|
||||
client: FlaskClient, redteam_token: str, loaded_bundle
|
||||
) -> None:
|
||||
eng = _make_engagement(client, redteam_token)
|
||||
sim = _make_sim(client, redteam_token, eng["id"])
|
||||
|
||||
resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059", "T1078"]})
|
||||
assert resp.status_code == 200
|
||||
techs = resp.get_json()["techniques"]
|
||||
assert len(techs) == 2
|
||||
ids = [t["id"] for t in techs]
|
||||
assert "T1059" in ids
|
||||
assert "T1078" in ids
|
||||
|
||||
|
||||
def test_patch_technique_ids_resolves_name(
|
||||
client: FlaskClient, redteam_token: str, loaded_bundle
|
||||
) -> None:
|
||||
eng = _make_engagement(client, redteam_token)
|
||||
sim = _make_sim(client, redteam_token, eng["id"])
|
||||
|
||||
resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059"]})
|
||||
assert resp.status_code == 200
|
||||
tech = resp.get_json()["techniques"][0]
|
||||
assert tech["name"] == "Command and Scripting Interpreter"
|
||||
|
||||
|
||||
def test_patch_technique_ids_unknown_returns_400(
|
||||
client: FlaskClient, redteam_token: str, loaded_bundle
|
||||
) -> None:
|
||||
eng = _make_engagement(client, redteam_token)
|
||||
sim = _make_sim(client, redteam_token, eng["id"])
|
||||
|
||||
resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T9999"]})
|
||||
assert resp.status_code == 400
|
||||
assert "unknown technique id: T9999" in resp.get_json()["error"]
|
||||
|
||||
|
||||
def test_patch_technique_ids_partial_unknown_rejected(
|
||||
client: FlaskClient, redteam_token: str, loaded_bundle
|
||||
) -> None:
|
||||
eng = _make_engagement(client, redteam_token)
|
||||
sim = _make_sim(client, redteam_token, eng["id"])
|
||||
|
||||
# One valid, one unknown — whole request rejected.
|
||||
resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059", "T9999"]})
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
def test_patch_technique_ids_includes_subtechnique(
|
||||
client: FlaskClient, redteam_token: str, loaded_bundle
|
||||
) -> None:
|
||||
eng = _make_engagement(client, redteam_token)
|
||||
sim = _make_sim(client, redteam_token, eng["id"])
|
||||
|
||||
resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059.001"]})
|
||||
assert resp.status_code == 200
|
||||
techs = resp.get_json()["techniques"]
|
||||
assert techs[0]["id"] == "T1059.001"
|
||||
assert techs[0]["name"] == "PowerShell"
|
||||
|
||||
|
||||
def test_patch_technique_ids_replaces_list(
|
||||
client: FlaskClient, redteam_token: str, loaded_bundle
|
||||
) -> None:
|
||||
eng = _make_engagement(client, redteam_token)
|
||||
sim = _make_sim(client, redteam_token, eng["id"])
|
||||
_patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059"]})
|
||||
|
||||
resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1078"]})
|
||||
assert resp.status_code == 200
|
||||
ids = [t["id"] for t in resp.get_json()["techniques"]]
|
||||
assert ids == ["T1078"]
|
||||
|
||||
|
||||
def test_patch_technique_ids_empty_clears_list(
|
||||
client: FlaskClient, redteam_token: str, loaded_bundle
|
||||
) -> None:
|
||||
eng = _make_engagement(client, redteam_token)
|
||||
sim = _make_sim(client, redteam_token, eng["id"])
|
||||
_patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059"]})
|
||||
|
||||
resp = _patch(client, redteam_token, sim["id"], {"technique_ids": []})
|
||||
assert resp.status_code == 200
|
||||
assert resp.get_json()["techniques"] == []
|
||||
|
||||
|
||||
def test_patch_technique_ids_not_list_returns_400(
|
||||
client: FlaskClient, redteam_token: str, loaded_bundle
|
||||
) -> None:
|
||||
eng = _make_engagement(client, redteam_token)
|
||||
sim = _make_sim(client, redteam_token, eng["id"])
|
||||
|
||||
resp = _patch(client, redteam_token, sim["id"], {"technique_ids": "T1059"})
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dedup (spec-reviewer note: AC-13.4)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_patch_technique_ids_deduplicates(
|
||||
client: FlaskClient, redteam_token: str, loaded_bundle
|
||||
) -> None:
|
||||
eng = _make_engagement(client, redteam_token)
|
||||
sim = _make_sim(client, redteam_token, eng["id"])
|
||||
|
||||
resp = _patch(
|
||||
client, redteam_token, sim["id"], {"technique_ids": ["T1059", "T1078", "T1059"]}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
techs = resp.get_json()["techniques"]
|
||||
assert len(techs) == 2
|
||||
# Order preserved: T1059 first.
|
||||
assert techs[0]["id"] == "T1059"
|
||||
assert techs[1]["id"] == "T1078"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AC-13.5 — auto-transition on technique_ids
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_technique_ids_non_empty_triggers_auto_transition(
|
||||
client: FlaskClient, redteam_token: str, loaded_bundle
|
||||
) -> None:
|
||||
eng = _make_engagement(client, redteam_token)
|
||||
sim = _make_sim(client, redteam_token, eng["id"])
|
||||
assert sim["status"] == "pending"
|
||||
|
||||
resp = _patch(client, redteam_token, sim["id"], {"technique_ids": ["T1059"]})
|
||||
assert resp.status_code == 200
|
||||
assert resp.get_json()["status"] == "in_progress"
|
||||
|
||||
|
||||
def test_technique_ids_empty_does_not_trigger_auto_transition(
|
||||
client: FlaskClient, redteam_token: str, loaded_bundle
|
||||
) -> None:
|
||||
eng = _make_engagement(client, redteam_token)
|
||||
sim = _make_sim(client, redteam_token, eng["id"])
|
||||
|
||||
resp = _patch(client, redteam_token, sim["id"], {"technique_ids": []})
|
||||
assert resp.status_code == 200
|
||||
assert resp.get_json()["status"] == "pending"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SOC cannot patch technique_ids (it's a redteam field)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_soc_cannot_patch_technique_ids(
|
||||
client: FlaskClient, redteam_token: str, soc_token: str, loaded_bundle
|
||||
) -> None:
|
||||
eng = _make_engagement(client, redteam_token)
|
||||
sim = _make_sim(client, redteam_token, eng["id"])
|
||||
# Advance to review_required so SOC can touch the simulation at all.
|
||||
client.post(
|
||||
f"/api/simulations/{sim['id']}/transition",
|
||||
headers=_h(redteam_token),
|
||||
json={"to": "review_required"},
|
||||
)
|
||||
|
||||
resp = _patch(client, soc_token, sim["id"], {"technique_ids": ["T1059"]})
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Migration backfill test (inline, no Alembic runner needed)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_migration_backfill_logic() -> None:
|
||||
"""Verify the backfill logic used in upgrade(): scalar → [{id, name}]."""
|
||||
import json as _json
|
||||
|
||||
def _backfill(tech_id, tech_name):
|
||||
if tech_id:
|
||||
return _json.loads(_json.dumps([{"id": tech_id, "name": tech_name or ""}]))
|
||||
return []
|
||||
|
||||
assert _backfill("T1059", "Command and Scripting Interpreter") == [
|
||||
{"id": "T1059", "name": "Command and Scripting Interpreter"}
|
||||
]
|
||||
assert _backfill(None, None) == []
|
||||
assert _backfill("T1059", None) == [{"id": "T1059", "name": ""}]
|
||||
Reference in New Issue
Block a user