feat(backend): sprint 2 — simulations + MITRE ATT&CK
- Simulation model with full field set (redteam + SOC sides) and cascade delete - Alembic migration 0002 for simulations table - simulation_workflow service: PATCH RBAC field-level + auto-transition pending→in_progress + state machine - mitre service: STIX bundle loader (boot-safe) + ranked search (exact-id > prefix-id > name) - 7 new API endpoints: list/create/get/patch/delete simulations, transition, MITRE autocomplete - serialize_simulation added to serializers.py - Makefile update-mitre target with real curl + optional docker restart - Dockerfile updated to copy backend/data/ into image - MITRE enterprise-attack.json bundle committed (~45 MB) - 67 new tests (total 130 passing), ruff clean, mypy introduces no new errors Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -6,7 +6,7 @@ from pathlib import Path
|
||||
|
||||
from flask import Flask, jsonify, send_from_directory
|
||||
|
||||
from backend.app.api import auth_bp, engagements_bp, users_bp
|
||||
from backend.app.api import auth_bp, engagements_bp, simulations_bp, users_bp
|
||||
from backend.app.cli import register_cli
|
||||
from backend.app.config import Config, TestConfig
|
||||
from backend.app.errors import register_error_handlers
|
||||
@@ -36,6 +36,10 @@ def create_app(config_object: object | None = None) -> Flask:
|
||||
app.register_blueprint(auth_bp)
|
||||
app.register_blueprint(users_bp)
|
||||
app.register_blueprint(engagements_bp)
|
||||
app.register_blueprint(simulations_bp)
|
||||
|
||||
from backend.app.services import mitre as mitre_svc
|
||||
mitre_svc.load_bundle()
|
||||
|
||||
register_error_handlers(app)
|
||||
register_cli(app)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""API blueprints."""
|
||||
from backend.app.api.auth import auth_bp
|
||||
from backend.app.api.engagements import engagements_bp
|
||||
from backend.app.api.simulations import simulations_bp
|
||||
from backend.app.api.users import users_bp
|
||||
|
||||
__all__ = ["auth_bp", "users_bp", "engagements_bp"]
|
||||
__all__ = ["auth_bp", "users_bp", "engagements_bp", "simulations_bp"]
|
||||
|
||||
141
backend/app/api/simulations.py
Normal file
141
backend/app/api/simulations.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""Simulation CRUD + workflow endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from flask import Blueprint, g, jsonify, request
|
||||
|
||||
from backend.app.auth import login_required, role_required
|
||||
from backend.app.extensions import db
|
||||
from backend.app.models import Engagement
|
||||
from backend.app.models.simulation import Simulation, SimulationStatus
|
||||
from backend.app.serializers import serialize_simulation
|
||||
from backend.app.services import simulation_workflow
|
||||
|
||||
simulations_bp = Blueprint("simulations", __name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Nested under /api/engagements/<eid>/simulations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@simulations_bp.get("/api/engagements/<int:eid>/simulations")
|
||||
@login_required
|
||||
def list_simulations(eid: int):
|
||||
engagement = db.session.get(Engagement, eid)
|
||||
if engagement is None:
|
||||
return jsonify({"error": "Engagement not found"}), 404
|
||||
sims = (
|
||||
Simulation.query.filter_by(engagement_id=eid)
|
||||
.order_by(Simulation.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
return jsonify([serialize_simulation(s) for s in sims]), 200
|
||||
|
||||
|
||||
@simulations_bp.post("/api/engagements/<int:eid>/simulations")
|
||||
@role_required("admin", "redteam")
|
||||
def create_simulation(eid: int):
|
||||
engagement = db.session.get(Engagement, eid)
|
||||
if engagement is None:
|
||||
return jsonify({"error": "Engagement not found"}), 404
|
||||
|
||||
data = request.get_json(silent=True) or {}
|
||||
name = (data.get("name") or "").strip()
|
||||
if not name:
|
||||
return jsonify({"error": "name is required"}), 400
|
||||
|
||||
sim = Simulation(
|
||||
engagement_id=eid,
|
||||
name=name,
|
||||
status=SimulationStatus.PENDING,
|
||||
created_at=datetime.now(UTC),
|
||||
created_by_id=g.current_user.id,
|
||||
)
|
||||
db.session.add(sim)
|
||||
db.session.commit()
|
||||
return jsonify(serialize_simulation(sim)), 201
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Flat /api/simulations/<sid>
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@simulations_bp.get("/api/simulations/<int:sid>")
|
||||
@login_required
|
||||
def get_simulation(sid: int):
|
||||
sim = db.session.get(Simulation, sid)
|
||||
if sim is None:
|
||||
return jsonify({"error": "Simulation not found"}), 404
|
||||
return jsonify(serialize_simulation(sim)), 200
|
||||
|
||||
|
||||
@simulations_bp.patch("/api/simulations/<int:sid>")
|
||||
@login_required
|
||||
def update_simulation(sid: int):
|
||||
sim = db.session.get(Simulation, sid)
|
||||
if sim is None:
|
||||
return jsonify({"error": "Simulation not found"}), 404
|
||||
|
||||
user = g.current_user
|
||||
if user.role.value not in ("admin", "redteam", "soc"):
|
||||
return jsonify({"error": "Forbidden"}), 403
|
||||
|
||||
data = request.get_json(silent=True) or {}
|
||||
if not data:
|
||||
return jsonify(serialize_simulation(sim)), 200
|
||||
|
||||
err = simulation_workflow.apply_patch(sim, data, user)
|
||||
if err is not None:
|
||||
return err
|
||||
|
||||
db.session.commit()
|
||||
return jsonify(serialize_simulation(sim)), 200
|
||||
|
||||
|
||||
@simulations_bp.delete("/api/simulations/<int:sid>")
|
||||
@role_required("admin", "redteam")
|
||||
def delete_simulation(sid: int):
|
||||
sim = db.session.get(Simulation, sid)
|
||||
if sim is None:
|
||||
return jsonify({"error": "Simulation not found"}), 404
|
||||
db.session.delete(sim)
|
||||
db.session.commit()
|
||||
return "", 204
|
||||
|
||||
|
||||
@simulations_bp.post("/api/simulations/<int:sid>/transition")
|
||||
@login_required
|
||||
def transition_simulation(sid: int):
|
||||
sim = db.session.get(Simulation, sid)
|
||||
if sim is None:
|
||||
return jsonify({"error": "Simulation not found"}), 404
|
||||
|
||||
data = request.get_json(silent=True) or {}
|
||||
to_status = data.get("to", "")
|
||||
|
||||
err = simulation_workflow.transition(sim, to_status, g.current_user)
|
||||
if err is not None:
|
||||
return err
|
||||
|
||||
return jsonify(serialize_simulation(sim)), 200
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MITRE autocomplete
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@simulations_bp.get("/api/mitre/techniques")
|
||||
@login_required
|
||||
def mitre_techniques():
|
||||
from backend.app.services import mitre as mitre_svc
|
||||
|
||||
if not mitre_svc.mitre_loaded:
|
||||
return jsonify({"error": "mitre bundle not loaded"}), 503
|
||||
|
||||
q = request.args.get("q", "").strip()
|
||||
results = mitre_svc.search(q)
|
||||
return jsonify(results), 200
|
||||
@@ -1,5 +1,6 @@
|
||||
"""SQLAlchemy models."""
|
||||
from backend.app.models.engagement import Engagement, EngagementStatus
|
||||
from backend.app.models.simulation import Simulation, SimulationStatus
|
||||
from backend.app.models.user import User, UserRole
|
||||
|
||||
__all__ = ["User", "UserRole", "Engagement", "EngagementStatus"]
|
||||
__all__ = ["User", "UserRole", "Engagement", "EngagementStatus", "Simulation", "SimulationStatus"]
|
||||
|
||||
62
backend/app/models/simulation.py
Normal file
62
backend/app/models/simulation.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""Simulation model."""
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from backend.app.extensions import db
|
||||
|
||||
|
||||
class SimulationStatus(str, enum.Enum):
|
||||
PENDING = "pending"
|
||||
IN_PROGRESS = "in_progress"
|
||||
REVIEW_REQUIRED = "review_required"
|
||||
DONE = "done"
|
||||
|
||||
|
||||
class Simulation(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "simulations"
|
||||
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
engagement_id = db.Column(
|
||||
db.Integer,
|
||||
db.ForeignKey("engagements.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
name = db.Column(db.String(255), nullable=False)
|
||||
mitre_technique_id = db.Column(db.String(32), nullable=True)
|
||||
mitre_technique_name = db.Column(db.String(255), nullable=True)
|
||||
description = db.Column(db.Text, nullable=True)
|
||||
commands = db.Column(db.Text, nullable=True)
|
||||
prerequisites = db.Column(db.Text, nullable=True)
|
||||
executed_at = db.Column(db.DateTime, nullable=True)
|
||||
execution_result = db.Column(db.Text, nullable=True)
|
||||
log_source = db.Column(db.Text, nullable=True)
|
||||
logs = db.Column(db.Text, nullable=True)
|
||||
soc_comment = db.Column(db.Text, nullable=True)
|
||||
incident_number = db.Column(db.String(128), nullable=True)
|
||||
status = db.Column(
|
||||
db.Enum(SimulationStatus, name="simulation_status"),
|
||||
nullable=False,
|
||||
default=SimulationStatus.PENDING,
|
||||
)
|
||||
created_at = db.Column(
|
||||
db.DateTime, nullable=False, default=lambda: datetime.now(UTC)
|
||||
)
|
||||
updated_at = db.Column(db.DateTime, nullable=True)
|
||||
created_by_id = db.Column(
|
||||
db.Integer,
|
||||
db.ForeignKey("users.id", ondelete="RESTRICT"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
engagement = db.relationship(
|
||||
"Engagement",
|
||||
backref=db.backref("simulations", cascade="all, delete-orphan", lazy="dynamic"),
|
||||
)
|
||||
created_by = db.relationship("User", backref="simulations", lazy="joined")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Simulation {self.id} {self.name!r}>"
|
||||
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
from typing import Any
|
||||
|
||||
from backend.app.models import Engagement, User
|
||||
from backend.app.models.simulation import Simulation
|
||||
|
||||
|
||||
def serialize_user(user: User) -> dict[str, Any]:
|
||||
@@ -19,6 +20,31 @@ def serialize_user_brief(user: User) -> dict[str, Any]:
|
||||
return {"id": user.id, "username": user.username}
|
||||
|
||||
|
||||
def serialize_simulation(simulation: Simulation) -> dict[str, Any]:
|
||||
return {
|
||||
"id": simulation.id,
|
||||
"engagement_id": simulation.engagement_id,
|
||||
"name": simulation.name,
|
||||
"mitre_technique_id": simulation.mitre_technique_id,
|
||||
"mitre_technique_name": simulation.mitre_technique_name,
|
||||
"description": simulation.description,
|
||||
"commands": simulation.commands,
|
||||
"prerequisites": simulation.prerequisites,
|
||||
"executed_at": simulation.executed_at.isoformat() if simulation.executed_at else None,
|
||||
"execution_result": simulation.execution_result,
|
||||
"log_source": simulation.log_source,
|
||||
"logs": simulation.logs,
|
||||
"soc_comment": simulation.soc_comment,
|
||||
"incident_number": simulation.incident_number,
|
||||
"status": simulation.status.value,
|
||||
"created_at": simulation.created_at.isoformat() if simulation.created_at else None,
|
||||
"updated_at": simulation.updated_at.isoformat() if simulation.updated_at else None,
|
||||
"created_by": serialize_user_brief(simulation.created_by) # type: ignore[arg-type]
|
||||
if simulation.created_by
|
||||
else None,
|
||||
}
|
||||
|
||||
|
||||
def serialize_engagement(engagement: Engagement) -> dict[str, Any]:
|
||||
return {
|
||||
"id": engagement.id,
|
||||
|
||||
0
backend/app/services/__init__.py
Normal file
0
backend/app/services/__init__.py
Normal file
100
backend/app/services/mitre.py
Normal file
100
backend/app/services/mitre.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""MITRE ATT&CK bundle loader and search service."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Absolute path to the committed bundle.
|
||||
_BUNDLE_PATH = Path(__file__).parent.parent.parent / "data" / "mitre" / "enterprise-attack.json"
|
||||
|
||||
mitre_loaded: bool = False
|
||||
_index: list[dict[str, Any]] = []
|
||||
|
||||
|
||||
def _extract_tactics(obj: dict[str, Any]) -> list[str]:
|
||||
phases = obj.get("kill_chain_phases") or []
|
||||
return [
|
||||
p["phase_name"]
|
||||
for p in phases
|
||||
if isinstance(p, dict) and "phase_name" in p
|
||||
]
|
||||
|
||||
|
||||
def _get_external_id(obj: dict[str, Any]) -> str | None:
|
||||
for ref in obj.get("external_references") or []:
|
||||
if isinstance(ref, dict) and ref.get("source_name") == "mitre-attack":
|
||||
return ref.get("external_id")
|
||||
return None
|
||||
|
||||
|
||||
def load_bundle(path: Path | None = None) -> None:
|
||||
"""Load the MITRE bundle into memory. Called once at app boot."""
|
||||
global mitre_loaded, _index
|
||||
bundle_path = path or _BUNDLE_PATH
|
||||
|
||||
try:
|
||||
raw = bundle_path.read_text(encoding="utf-8")
|
||||
data = json.loads(raw)
|
||||
except FileNotFoundError:
|
||||
logger.warning("MITRE bundle not found at %s — autocomplete disabled", bundle_path)
|
||||
mitre_loaded = False
|
||||
return
|
||||
except (json.JSONDecodeError, OSError) as exc:
|
||||
logger.warning("MITRE bundle parse error: %s — autocomplete disabled", exc)
|
||||
mitre_loaded = False
|
||||
return
|
||||
|
||||
entries: list[dict[str, Any]] = []
|
||||
for obj in data.get("objects") or []:
|
||||
if not isinstance(obj, dict):
|
||||
continue
|
||||
if obj.get("type") != "attack-pattern":
|
||||
continue
|
||||
if obj.get("revoked") or obj.get("x_mitre_deprecated"):
|
||||
continue
|
||||
ext_id = _get_external_id(obj)
|
||||
if not ext_id:
|
||||
continue
|
||||
entries.append(
|
||||
{
|
||||
"id": ext_id,
|
||||
"name": obj.get("name", ""),
|
||||
"tactics": _extract_tactics(obj),
|
||||
}
|
||||
)
|
||||
|
||||
_index = entries
|
||||
mitre_loaded = True
|
||||
logger.info("MITRE bundle loaded: %d techniques", len(_index))
|
||||
|
||||
|
||||
def search(query: str, limit: int = 20) -> list[dict[str, Any]]:
|
||||
"""Return up to `limit` techniques matching `query`.
|
||||
|
||||
Ranking: exact id > prefix id > substring name (case-insensitive).
|
||||
"""
|
||||
q = query.strip().upper()
|
||||
if not q:
|
||||
return []
|
||||
|
||||
exact: list[dict[str, Any]] = []
|
||||
prefix: list[dict[str, Any]] = []
|
||||
name_match: list[dict[str, Any]] = []
|
||||
|
||||
for entry in _index:
|
||||
tech_id = entry["id"].upper()
|
||||
tech_name = entry["name"].upper()
|
||||
|
||||
if tech_id == q:
|
||||
exact.append(entry)
|
||||
elif tech_id.startswith(q):
|
||||
prefix.append(entry)
|
||||
elif q in tech_name:
|
||||
name_match.append(entry)
|
||||
|
||||
combined = exact + prefix + name_match
|
||||
return combined[:limit]
|
||||
129
backend/app/services/simulation_workflow.py
Normal file
129
backend/app/services/simulation_workflow.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""Simulation business logic: PATCH rules and state machine transitions."""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from flask import jsonify
|
||||
|
||||
from backend.app.extensions import db
|
||||
from backend.app.models import User
|
||||
from backend.app.models.simulation import Simulation, SimulationStatus
|
||||
|
||||
REDTEAM_FIELDS = frozenset(
|
||||
{
|
||||
"name",
|
||||
"mitre_technique_id",
|
||||
"mitre_technique_name",
|
||||
"description",
|
||||
"commands",
|
||||
"prerequisites",
|
||||
"executed_at",
|
||||
"execution_result",
|
||||
}
|
||||
)
|
||||
|
||||
SOC_FIELDS = frozenset({"log_source", "logs", "soc_comment", "incident_number"})
|
||||
|
||||
# Transitions allowed via POST /transition endpoint (manual only).
|
||||
# auto pending→in_progress is handled in apply_patch, not here.
|
||||
_ALLOWED_TRANSITIONS: dict[str, dict[str, set[str]]] = {
|
||||
"review_required": {
|
||||
"from": {"pending", "in_progress"},
|
||||
"roles": {"admin", "redteam"},
|
||||
},
|
||||
"done": {
|
||||
"from": {"review_required"},
|
||||
"roles": {"admin", "redteam", "soc"},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _is_non_empty(value: Any) -> bool:
|
||||
"""Return True if value counts as "filled" for auto-transition purposes."""
|
||||
if value is None:
|
||||
return False
|
||||
if isinstance(value, str) and value == "":
|
||||
return False
|
||||
return not (isinstance(value, list) and len(value) == 0)
|
||||
|
||||
|
||||
def apply_patch(
|
||||
simulation: Simulation, payload: dict[str, Any], user: User
|
||||
) -> tuple[Any, int] | None:
|
||||
"""Apply a validated PATCH payload to a simulation.
|
||||
|
||||
Returns a (response, status_code) tuple on error, or None on success
|
||||
(caller is responsible for committing).
|
||||
"""
|
||||
role = user.role.value
|
||||
|
||||
if role == "soc":
|
||||
# SOC can only patch when status allows it.
|
||||
if simulation.status not in (
|
||||
SimulationStatus.REVIEW_REQUIRED,
|
||||
SimulationStatus.DONE,
|
||||
):
|
||||
return jsonify({"error": "simulation not ready for SOC review"}), 403
|
||||
|
||||
# SOC must not send redteam fields.
|
||||
redteam_keys_in_payload = REDTEAM_FIELDS & payload.keys()
|
||||
if redteam_keys_in_payload:
|
||||
return jsonify({"error": "soc cannot edit redteam fields"}), 403
|
||||
|
||||
for field in SOC_FIELDS:
|
||||
if field in payload:
|
||||
setattr(simulation, field, payload[field])
|
||||
|
||||
else:
|
||||
# admin / redteam: apply all fields present.
|
||||
redteam_keys_present = REDTEAM_FIELDS & payload.keys()
|
||||
|
||||
for field in redteam_keys_present:
|
||||
if field == "executed_at":
|
||||
val = payload["executed_at"]
|
||||
if val is None:
|
||||
simulation.executed_at = None
|
||||
else:
|
||||
if not isinstance(val, str):
|
||||
return jsonify({"error": "invalid executed_at"}), 400
|
||||
try:
|
||||
simulation.executed_at = datetime.fromisoformat(val)
|
||||
except ValueError:
|
||||
return jsonify({"error": "invalid executed_at"}), 400
|
||||
else:
|
||||
setattr(simulation, field, payload[field])
|
||||
|
||||
for field in SOC_FIELDS:
|
||||
if field in payload:
|
||||
setattr(simulation, field, payload[field])
|
||||
|
||||
# Auto-transition pending → in_progress: at least one redteam field with
|
||||
# a non-empty value in the *incoming payload*.
|
||||
if simulation.status == SimulationStatus.PENDING and any(
|
||||
_is_non_empty(payload[k]) for k in redteam_keys_present
|
||||
):
|
||||
simulation.status = SimulationStatus.IN_PROGRESS
|
||||
|
||||
simulation.updated_at = datetime.now(UTC)
|
||||
return None
|
||||
|
||||
|
||||
def transition(
|
||||
simulation: Simulation, to_status: str, user: User
|
||||
) -> tuple[Any, int] | None:
|
||||
"""Attempt a manual transition. Returns error tuple or None on success."""
|
||||
rule = _ALLOWED_TRANSITIONS.get(to_status)
|
||||
if rule is None:
|
||||
return jsonify({"error": "invalid transition"}), 409
|
||||
|
||||
if simulation.status.value not in rule["from"]:
|
||||
return jsonify({"error": "invalid transition"}), 409
|
||||
|
||||
if user.role.value not in rule["roles"]:
|
||||
return jsonify({"error": "Forbidden"}), 403
|
||||
|
||||
simulation.status = SimulationStatus(to_status)
|
||||
simulation.updated_at = datetime.now(UTC)
|
||||
db.session.commit()
|
||||
return None
|
||||
Reference in New Issue
Block a user