feat: sprint 5 — simulation templates + instantiation + nav + dropdown #8
@@ -6,7 +6,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
from flask import Flask, jsonify, send_from_directory
|
from flask import Flask, jsonify, send_from_directory
|
||||||
|
|
||||||
from backend.app.api import auth_bp, engagements_bp, simulations_bp, users_bp
|
from backend.app.api import auth_bp, engagements_bp, simulations_bp, templates_bp, users_bp
|
||||||
from backend.app.cli import register_cli
|
from backend.app.cli import register_cli
|
||||||
from backend.app.config import Config, TestConfig
|
from backend.app.config import Config, TestConfig
|
||||||
from backend.app.errors import register_error_handlers
|
from backend.app.errors import register_error_handlers
|
||||||
@@ -37,6 +37,7 @@ def create_app(config_object: object | None = None) -> Flask:
|
|||||||
app.register_blueprint(users_bp)
|
app.register_blueprint(users_bp)
|
||||||
app.register_blueprint(engagements_bp)
|
app.register_blueprint(engagements_bp)
|
||||||
app.register_blueprint(simulations_bp)
|
app.register_blueprint(simulations_bp)
|
||||||
|
app.register_blueprint(templates_bp)
|
||||||
|
|
||||||
from backend.app.services import mitre as mitre_svc
|
from backend.app.services import mitre as mitre_svc
|
||||||
mitre_svc.load_bundle()
|
mitre_svc.load_bundle()
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
from backend.app.api.auth import auth_bp
|
from backend.app.api.auth import auth_bp
|
||||||
from backend.app.api.engagements import engagements_bp
|
from backend.app.api.engagements import engagements_bp
|
||||||
from backend.app.api.simulations import simulations_bp
|
from backend.app.api.simulations import simulations_bp
|
||||||
|
from backend.app.api.templates import templates_bp
|
||||||
from backend.app.api.users import users_bp
|
from backend.app.api.users import users_bp
|
||||||
|
|
||||||
__all__ = ["auth_bp", "users_bp", "engagements_bp", "simulations_bp"]
|
__all__ = ["auth_bp", "users_bp", "engagements_bp", "simulations_bp", "templates_bp"]
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ def create_simulation(eid: int):
|
|||||||
if not name:
|
if not name:
|
||||||
return jsonify({"error": "name is required"}), 400
|
return jsonify({"error": "name is required"}), 400
|
||||||
|
|
||||||
|
template_id = data.get("template_id")
|
||||||
sim = Simulation(
|
sim = Simulation(
|
||||||
engagement_id=eid,
|
engagement_id=eid,
|
||||||
name=name,
|
name=name,
|
||||||
@@ -53,6 +54,19 @@ def create_simulation(eid: int):
|
|||||||
created_at=datetime.now(UTC),
|
created_at=datetime.now(UTC),
|
||||||
created_by_id=g.current_user.id,
|
created_by_id=g.current_user.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if template_id is not None:
|
||||||
|
from backend.app.models.simulation_template import SimulationTemplate
|
||||||
|
|
||||||
|
tmpl = db.session.get(SimulationTemplate, template_id)
|
||||||
|
if tmpl is None:
|
||||||
|
return jsonify({"error": "Template not found"}), 404
|
||||||
|
sim.description = tmpl.description
|
||||||
|
sim.commands = tmpl.commands
|
||||||
|
sim.prerequisites = tmpl.prerequisites
|
||||||
|
sim.techniques = list(tmpl.techniques or [])
|
||||||
|
sim.tactic_ids = list(tmpl.tactic_ids or [])
|
||||||
|
|
||||||
db.session.add(sim)
|
db.session.add(sim)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
return jsonify(serialize_simulation(sim)), 201
|
return jsonify(serialize_simulation(sim)), 201
|
||||||
|
|||||||
143
backend/app/api/templates.py
Normal file
143
backend/app/api/templates.py
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
"""SimulationTemplate CRUD endpoints — admin and redteam only."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
import sqlalchemy.exc
|
||||||
|
from flask import Blueprint, g, jsonify, request
|
||||||
|
|
||||||
|
from backend.app.auth import role_required
|
||||||
|
from backend.app.extensions import db
|
||||||
|
from backend.app.models.simulation_template import SimulationTemplate
|
||||||
|
from backend.app.serializers import serialize_template
|
||||||
|
from backend.app.services import mitre as mitre_svc
|
||||||
|
from backend.app.services.simulation_workflow import (
|
||||||
|
_resolve_tactic_ids,
|
||||||
|
_resolve_technique_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
templates_bp = Blueprint("templates", __name__)
|
||||||
|
|
||||||
|
_MUTABLE_FIELDS = {"name", "description", "commands", "prerequisites", "technique_ids", "tactic_ids"}
|
||||||
|
|
||||||
|
|
||||||
|
@templates_bp.get("/api/templates")
|
||||||
|
@role_required("admin", "redteam")
|
||||||
|
def list_templates():
|
||||||
|
items = SimulationTemplate.query.order_by(SimulationTemplate.name).all()
|
||||||
|
return jsonify([serialize_template(t) for t in items]), 200
|
||||||
|
|
||||||
|
|
||||||
|
@templates_bp.post("/api/templates")
|
||||||
|
@role_required("admin", "redteam")
|
||||||
|
def create_template():
|
||||||
|
data = request.get_json(silent=True) or {}
|
||||||
|
name = (data.get("name") or "").strip()
|
||||||
|
if not name:
|
||||||
|
return jsonify({"error": "name is required"}), 400
|
||||||
|
|
||||||
|
techniques: list[dict] = []
|
||||||
|
tactic_ids_val: list[str] = []
|
||||||
|
|
||||||
|
if "technique_ids" in data:
|
||||||
|
if not mitre_svc.mitre_loaded:
|
||||||
|
return jsonify({"error": "mitre bundle not loaded"}), 503
|
||||||
|
resolved, err = _resolve_technique_ids(data["technique_ids"])
|
||||||
|
if err is not None:
|
||||||
|
return err
|
||||||
|
techniques = resolved or []
|
||||||
|
|
||||||
|
if "tactic_ids" in data:
|
||||||
|
resolved_ta, err = _resolve_tactic_ids(data["tactic_ids"])
|
||||||
|
if err is not None:
|
||||||
|
return err
|
||||||
|
tactic_ids_val = resolved_ta or []
|
||||||
|
|
||||||
|
tmpl = SimulationTemplate(
|
||||||
|
name=name,
|
||||||
|
description=data.get("description"),
|
||||||
|
commands=data.get("commands"),
|
||||||
|
prerequisites=data.get("prerequisites"),
|
||||||
|
techniques=techniques,
|
||||||
|
tactic_ids=tactic_ids_val,
|
||||||
|
created_at=datetime.now(UTC),
|
||||||
|
created_by_id=g.current_user.id,
|
||||||
|
)
|
||||||
|
db.session.add(tmpl)
|
||||||
|
try:
|
||||||
|
db.session.commit()
|
||||||
|
except sqlalchemy.exc.IntegrityError:
|
||||||
|
db.session.rollback()
|
||||||
|
return jsonify({"error": "template name already exists"}), 409
|
||||||
|
|
||||||
|
return jsonify(serialize_template(tmpl)), 201
|
||||||
|
|
||||||
|
|
||||||
|
@templates_bp.get("/api/templates/<int:tid>")
|
||||||
|
@role_required("admin", "redteam")
|
||||||
|
def get_template(tid: int):
|
||||||
|
tmpl = db.session.get(SimulationTemplate, tid)
|
||||||
|
if tmpl is None:
|
||||||
|
return jsonify({"error": "Template not found"}), 404
|
||||||
|
return jsonify(serialize_template(tmpl)), 200
|
||||||
|
|
||||||
|
|
||||||
|
@templates_bp.patch("/api/templates/<int:tid>")
|
||||||
|
@role_required("admin", "redteam")
|
||||||
|
def update_template(tid: int):
|
||||||
|
tmpl = db.session.get(SimulationTemplate, tid)
|
||||||
|
if tmpl is None:
|
||||||
|
return jsonify({"error": "Template not found"}), 404
|
||||||
|
|
||||||
|
data = request.get_json(silent=True) or {}
|
||||||
|
unknown = set(data.keys()) - _MUTABLE_FIELDS
|
||||||
|
if unknown:
|
||||||
|
return jsonify({"error": f"unknown fields: {sorted(unknown)}"}), 400
|
||||||
|
|
||||||
|
if not data:
|
||||||
|
return jsonify(serialize_template(tmpl)), 200
|
||||||
|
|
||||||
|
if "name" in data:
|
||||||
|
name = (data["name"] or "").strip()
|
||||||
|
if not name:
|
||||||
|
return jsonify({"error": "name cannot be empty"}), 400
|
||||||
|
tmpl.name = name
|
||||||
|
|
||||||
|
for field in ("description", "commands", "prerequisites"):
|
||||||
|
if field in data:
|
||||||
|
setattr(tmpl, field, data[field])
|
||||||
|
|
||||||
|
if "technique_ids" in data:
|
||||||
|
if not mitre_svc.mitre_loaded:
|
||||||
|
return jsonify({"error": "mitre bundle not loaded"}), 503
|
||||||
|
resolved, err = _resolve_technique_ids(data["technique_ids"])
|
||||||
|
if err is not None:
|
||||||
|
return err
|
||||||
|
tmpl.techniques = resolved
|
||||||
|
|
||||||
|
if "tactic_ids" in data:
|
||||||
|
resolved_ta, err = _resolve_tactic_ids(data["tactic_ids"])
|
||||||
|
if err is not None:
|
||||||
|
return err
|
||||||
|
tmpl.tactic_ids = resolved_ta
|
||||||
|
|
||||||
|
tmpl.updated_at = datetime.now(UTC)
|
||||||
|
|
||||||
|
try:
|
||||||
|
db.session.commit()
|
||||||
|
except sqlalchemy.exc.IntegrityError:
|
||||||
|
db.session.rollback()
|
||||||
|
return jsonify({"error": "template name already exists"}), 409
|
||||||
|
|
||||||
|
return jsonify(serialize_template(tmpl)), 200
|
||||||
|
|
||||||
|
|
||||||
|
@templates_bp.delete("/api/templates/<int:tid>")
|
||||||
|
@role_required("admin", "redteam")
|
||||||
|
def delete_template(tid: int):
|
||||||
|
tmpl = db.session.get(SimulationTemplate, tid)
|
||||||
|
if tmpl is None:
|
||||||
|
return jsonify({"error": "Template not found"}), 404
|
||||||
|
db.session.delete(tmpl)
|
||||||
|
db.session.commit()
|
||||||
|
return "", 204
|
||||||
@@ -1,6 +1,15 @@
|
|||||||
"""SQLAlchemy models."""
|
"""SQLAlchemy models."""
|
||||||
from backend.app.models.engagement import Engagement, EngagementStatus
|
from backend.app.models.engagement import Engagement, EngagementStatus
|
||||||
from backend.app.models.simulation import Simulation, SimulationStatus
|
from backend.app.models.simulation import Simulation, SimulationStatus
|
||||||
|
from backend.app.models.simulation_template import SimulationTemplate
|
||||||
from backend.app.models.user import User, UserRole
|
from backend.app.models.user import User, UserRole
|
||||||
|
|
||||||
__all__ = ["User", "UserRole", "Engagement", "EngagementStatus", "Simulation", "SimulationStatus"]
|
__all__ = [
|
||||||
|
"User",
|
||||||
|
"UserRole",
|
||||||
|
"Engagement",
|
||||||
|
"EngagementStatus",
|
||||||
|
"Simulation",
|
||||||
|
"SimulationStatus",
|
||||||
|
"SimulationTemplate",
|
||||||
|
]
|
||||||
|
|||||||
32
backend/app/models/simulation_template.py
Normal file
32
backend/app/models/simulation_template.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""SimulationTemplate model."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
from backend.app.extensions import db
|
||||||
|
|
||||||
|
|
||||||
|
class SimulationTemplate(db.Model): # type: ignore[name-defined]
|
||||||
|
__tablename__ = "simulation_templates"
|
||||||
|
|
||||||
|
id = db.Column(db.Integer, primary_key=True)
|
||||||
|
name = db.Column(db.String(255), nullable=False, unique=True)
|
||||||
|
description = db.Column(db.Text, nullable=True)
|
||||||
|
commands = db.Column(db.Text, nullable=True)
|
||||||
|
prerequisites = db.Column(db.Text, nullable=True)
|
||||||
|
techniques = db.Column(db.JSON, nullable=False, default=list)
|
||||||
|
tactic_ids = db.Column(db.JSON, nullable=False, default=list)
|
||||||
|
created_at = db.Column(
|
||||||
|
db.DateTime, nullable=False, default=lambda: datetime.now(UTC)
|
||||||
|
)
|
||||||
|
updated_at = db.Column(db.DateTime, nullable=True)
|
||||||
|
created_by_id = db.Column(
|
||||||
|
db.Integer,
|
||||||
|
db.ForeignKey("users.id", ondelete="RESTRICT"),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
created_by = db.relationship("User", lazy="joined")
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<SimulationTemplate {self.id} {self.name!r}>"
|
||||||
@@ -5,6 +5,7 @@ from typing import Any
|
|||||||
|
|
||||||
from backend.app.models import Engagement, User
|
from backend.app.models import Engagement, User
|
||||||
from backend.app.models.simulation import Simulation
|
from backend.app.models.simulation import Simulation
|
||||||
|
from backend.app.models.simulation_template import SimulationTemplate
|
||||||
|
|
||||||
|
|
||||||
def serialize_user(user: User) -> dict[str, Any]:
|
def serialize_user(user: User) -> dict[str, Any]:
|
||||||
@@ -69,6 +70,23 @@ def serialize_simulation(simulation: Simulation) -> dict[str, Any]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_template(t: SimulationTemplate) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"id": t.id,
|
||||||
|
"name": t.name,
|
||||||
|
"description": t.description,
|
||||||
|
"commands": t.commands,
|
||||||
|
"prerequisites": t.prerequisites,
|
||||||
|
"techniques": _enrich_techniques(t.techniques or []),
|
||||||
|
"tactics": _enrich_tactics(t.tactic_ids or []),
|
||||||
|
"created_at": t.created_at.isoformat() if t.created_at else None,
|
||||||
|
"updated_at": t.updated_at.isoformat() if t.updated_at else None,
|
||||||
|
"created_by": serialize_user_brief(t.created_by) # type: ignore[arg-type]
|
||||||
|
if t.created_by
|
||||||
|
else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def serialize_engagement(engagement: Engagement) -> dict[str, Any]:
|
def serialize_engagement(engagement: Engagement) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"id": engagement.id,
|
"id": engagement.id,
|
||||||
|
|||||||
40
backend/migrations/versions/0005_simulation_templates.py
Normal file
40
backend/migrations/versions/0005_simulation_templates.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
"""create simulation_templates table
|
||||||
|
|
||||||
|
Revision ID: 0005
|
||||||
|
Revises: 0004
|
||||||
|
Create Date: 2026-05-28 00:00:00.000000
|
||||||
|
"""
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision = "0005"
|
||||||
|
down_revision = "0004"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"simulation_templates",
|
||||||
|
sa.Column("id", sa.Integer(), primary_key=True),
|
||||||
|
sa.Column("name", sa.String(length=255), nullable=False, unique=True),
|
||||||
|
sa.Column("description", sa.Text(), nullable=True),
|
||||||
|
sa.Column("commands", sa.Text(), nullable=True),
|
||||||
|
sa.Column("prerequisites", sa.Text(), nullable=True),
|
||||||
|
sa.Column("techniques", sa.JSON(), nullable=False, server_default=sa.text("'[]'")),
|
||||||
|
sa.Column("tactic_ids", sa.JSON(), nullable=False, server_default=sa.text("'[]'")),
|
||||||
|
sa.Column("created_at", sa.DateTime(), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"created_by_id",
|
||||||
|
sa.Integer(),
|
||||||
|
sa.ForeignKey("users.id", ondelete="RESTRICT"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
op.create_index("ix_simulation_templates_name", "simulation_templates", ["name"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("ix_simulation_templates_name", "simulation_templates")
|
||||||
|
op.drop_table("simulation_templates")
|
||||||
239
backend/tests/test_simulation_templates_crud.py
Normal file
239
backend/tests/test_simulation_templates_crud.py
Normal file
@@ -0,0 +1,239 @@
|
|||||||
|
"""SimulationTemplate CRUD: list, create, get, patch, delete + RBAC + dedup."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from flask.testing import FlaskClient
|
||||||
|
|
||||||
|
from backend.app.extensions import db
|
||||||
|
from backend.app.models import User
|
||||||
|
from backend.app.models.simulation_template import SimulationTemplate
|
||||||
|
from backend.tests.conftest import auth_headers as _h # noqa: E402
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_template(client: FlaskClient, token: str, **kw) -> dict:
|
||||||
|
payload = {"name": "Template Alpha", **kw}
|
||||||
|
resp = client.post("/api/templates", headers=_h(token), json=payload)
|
||||||
|
assert resp.status_code == 201, resp.get_json()
|
||||||
|
return resp.get_json()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# List
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_templates_empty(client: FlaskClient, admin_token: str) -> None:
|
||||||
|
resp = client.get("/api/templates", headers=_h(admin_token))
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.get_json() == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_templates_soc_forbidden(client: FlaskClient, soc_token: str) -> None:
|
||||||
|
resp = client.get("/api/templates", headers=_h(soc_token))
|
||||||
|
assert resp.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_templates_unauthenticated(client: FlaskClient) -> None:
|
||||||
|
resp = client.get("/api/templates")
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Create
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_template_as_admin(
|
||||||
|
client: FlaskClient, admin_user: User, admin_token: str
|
||||||
|
) -> None:
|
||||||
|
body = _make_template(
|
||||||
|
client,
|
||||||
|
admin_token,
|
||||||
|
description="desc",
|
||||||
|
commands="cmd",
|
||||||
|
prerequisites="prereq",
|
||||||
|
)
|
||||||
|
assert body["name"] == "Template Alpha"
|
||||||
|
assert body["description"] == "desc"
|
||||||
|
assert body["commands"] == "cmd"
|
||||||
|
assert body["prerequisites"] == "prereq"
|
||||||
|
assert body["techniques"] == []
|
||||||
|
assert body["tactics"] == []
|
||||||
|
assert body["created_by"] == {"id": admin_user.id, "username": "admin1"}
|
||||||
|
assert body["id"] is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_template_as_redteam(
|
||||||
|
client: FlaskClient, redteam_user: User, redteam_token: str
|
||||||
|
) -> None:
|
||||||
|
body = _make_template(client, redteam_token)
|
||||||
|
assert body["created_by"]["username"] == "redteam1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_template_soc_forbidden(client: FlaskClient, soc_token: str) -> None:
|
||||||
|
resp = client.post(
|
||||||
|
"/api/templates", headers=_h(soc_token), json={"name": "T"}
|
||||||
|
)
|
||||||
|
assert resp.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_template_missing_name(client: FlaskClient, admin_token: str) -> None:
|
||||||
|
resp = client.post("/api/templates", headers=_h(admin_token), json={})
|
||||||
|
assert resp.status_code == 400
|
||||||
|
assert "name" in resp.get_json()["error"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_template_duplicate_name_409(
|
||||||
|
client: FlaskClient, admin_token: str
|
||||||
|
) -> None:
|
||||||
|
_make_template(client, admin_token)
|
||||||
|
resp = client.post(
|
||||||
|
"/api/templates", headers=_h(admin_token), json={"name": "Template Alpha"}
|
||||||
|
)
|
||||||
|
assert resp.status_code == 409
|
||||||
|
assert "already exists" in resp.get_json()["error"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Get single
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_template(client: FlaskClient, admin_token: str) -> None:
|
||||||
|
created = _make_template(client, admin_token)
|
||||||
|
resp = client.get(f"/api/templates/{created['id']}", headers=_h(admin_token))
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.get_json()["id"] == created["id"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_template_not_found(client: FlaskClient, admin_token: str) -> None:
|
||||||
|
resp = client.get("/api/templates/9999", headers=_h(admin_token))
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_template_soc_forbidden(
|
||||||
|
client: FlaskClient, admin_token: str, soc_token: str
|
||||||
|
) -> None:
|
||||||
|
created = _make_template(client, admin_token)
|
||||||
|
resp = client.get(f"/api/templates/{created['id']}", headers=_h(soc_token))
|
||||||
|
assert resp.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Patch
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_patch_template_name(client: FlaskClient, admin_token: str) -> None:
|
||||||
|
created = _make_template(client, admin_token)
|
||||||
|
resp = client.patch(
|
||||||
|
f"/api/templates/{created['id']}",
|
||||||
|
headers=_h(admin_token),
|
||||||
|
json={"name": "Renamed"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.get_json()["name"] == "Renamed"
|
||||||
|
assert resp.get_json()["updated_at"] is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_patch_template_empty_name_rejected(
|
||||||
|
client: FlaskClient, admin_token: str
|
||||||
|
) -> None:
|
||||||
|
created = _make_template(client, admin_token)
|
||||||
|
resp = client.patch(
|
||||||
|
f"/api/templates/{created['id']}",
|
||||||
|
headers=_h(admin_token),
|
||||||
|
json={"name": ""},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
def test_patch_template_unknown_field_rejected(
|
||||||
|
client: FlaskClient, admin_token: str
|
||||||
|
) -> None:
|
||||||
|
created = _make_template(client, admin_token)
|
||||||
|
resp = client.patch(
|
||||||
|
f"/api/templates/{created['id']}",
|
||||||
|
headers=_h(admin_token),
|
||||||
|
json={"bogus_field": "x"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 400
|
||||||
|
assert "unknown fields" in resp.get_json()["error"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_patch_template_duplicate_name_409(
|
||||||
|
client: FlaskClient, admin_token: str
|
||||||
|
) -> None:
|
||||||
|
_make_template(client, admin_token, name="T1")
|
||||||
|
t2 = _make_template(client, admin_token, name="T2")
|
||||||
|
resp = client.patch(
|
||||||
|
f"/api/templates/{t2['id']}",
|
||||||
|
headers=_h(admin_token),
|
||||||
|
json={"name": "T1"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 409
|
||||||
|
|
||||||
|
|
||||||
|
def test_patch_template_soc_forbidden(
|
||||||
|
client: FlaskClient, admin_token: str, soc_token: str
|
||||||
|
) -> None:
|
||||||
|
created = _make_template(client, admin_token)
|
||||||
|
resp = client.patch(
|
||||||
|
f"/api/templates/{created['id']}",
|
||||||
|
headers=_h(soc_token),
|
||||||
|
json={"name": "X"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
def test_patch_template_not_found(client: FlaskClient, admin_token: str) -> None:
|
||||||
|
resp = client.patch(
|
||||||
|
"/api/templates/9999", headers=_h(admin_token), json={"name": "X"}
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Delete
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_template(
|
||||||
|
client: FlaskClient, app, admin_token: str
|
||||||
|
) -> None:
|
||||||
|
created = _make_template(client, admin_token)
|
||||||
|
resp = client.delete(f"/api/templates/{created['id']}", headers=_h(admin_token))
|
||||||
|
assert resp.status_code == 204
|
||||||
|
with app.app_context():
|
||||||
|
assert db.session.get(SimulationTemplate, created["id"]) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_template_not_found(client: FlaskClient, admin_token: str) -> None:
|
||||||
|
resp = client.delete("/api/templates/9999", headers=_h(admin_token))
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_template_soc_forbidden(
|
||||||
|
client: FlaskClient, admin_token: str, soc_token: str
|
||||||
|
) -> None:
|
||||||
|
created = _make_template(client, admin_token)
|
||||||
|
resp = client.delete(f"/api/templates/{created['id']}", headers=_h(soc_token))
|
||||||
|
assert resp.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# List returns ordered by name
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_templates_ordered_by_name(
|
||||||
|
client: FlaskClient, admin_token: str
|
||||||
|
) -> None:
|
||||||
|
for name in ("Zebra", "Alpha", "Midpoint"):
|
||||||
|
_make_template(client, admin_token, name=name)
|
||||||
|
body = client.get("/api/templates", headers=_h(admin_token)).get_json()
|
||||||
|
names = [t["name"] for t in body]
|
||||||
|
assert names == sorted(names)
|
||||||
195
backend/tests/test_simulations_from_template.py
Normal file
195
backend/tests/test_simulations_from_template.py
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
"""Tests for creating simulations from a template (POST /api/engagements/<eid>/simulations)."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from alembic.operations import Operations
|
||||||
|
from alembic.runtime.migration import MigrationContext
|
||||||
|
from flask.testing import FlaskClient
|
||||||
|
from sqlalchemy import create_engine, text
|
||||||
|
|
||||||
|
from backend.tests.conftest import auth_headers as _h
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_engagement(client: FlaskClient, token: str) -> dict:
|
||||||
|
resp = client.post(
|
||||||
|
"/api/engagements",
|
||||||
|
headers=_h(token),
|
||||||
|
json={"name": "Op Bravo", "start_date": "2026-06-01"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 201, resp.get_json()
|
||||||
|
return resp.get_json()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_template(client: FlaskClient, token: str, **kw) -> dict:
|
||||||
|
payload = {"name": "Base Template", **kw}
|
||||||
|
resp = client.post("/api/templates", headers=_h(token), json=payload)
|
||||||
|
assert resp.status_code == 201, resp.get_json()
|
||||||
|
return resp.get_json()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_sim(client: FlaskClient, token: str, eid: int, **kw) -> dict:
|
||||||
|
payload = {"name": "Sim From Template", **kw}
|
||||||
|
resp = client.post(
|
||||||
|
f"/api/engagements/{eid}/simulations", headers=_h(token), json=payload
|
||||||
|
)
|
||||||
|
assert resp.status_code == 201, resp.get_json()
|
||||||
|
return resp.get_json()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Instantiation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_simulation_from_template_copies_fields(
|
||||||
|
client: FlaskClient, admin_token: str
|
||||||
|
) -> None:
|
||||||
|
eng = _make_engagement(client, admin_token)
|
||||||
|
tmpl = _make_template(
|
||||||
|
client,
|
||||||
|
admin_token,
|
||||||
|
description="template desc",
|
||||||
|
commands="template cmd",
|
||||||
|
prerequisites="template prereq",
|
||||||
|
)
|
||||||
|
sim = _make_sim(client, admin_token, eng["id"], template_id=tmpl["id"])
|
||||||
|
|
||||||
|
assert sim["description"] == "template desc"
|
||||||
|
assert sim["commands"] == "template cmd"
|
||||||
|
assert sim["prerequisites"] == "template prereq"
|
||||||
|
assert sim["techniques"] == []
|
||||||
|
assert sim["tactics"] == []
|
||||||
|
assert sim["status"] == "pending"
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_simulation_name_overrides_template(
|
||||||
|
client: FlaskClient, admin_token: str
|
||||||
|
) -> None:
|
||||||
|
eng = _make_engagement(client, admin_token)
|
||||||
|
tmpl = _make_template(client, admin_token)
|
||||||
|
sim = _make_sim(
|
||||||
|
client, admin_token, eng["id"], name="Custom Name", template_id=tmpl["id"]
|
||||||
|
)
|
||||||
|
assert sim["name"] == "Custom Name"
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_simulation_template_not_found(
|
||||||
|
client: FlaskClient, admin_token: str
|
||||||
|
) -> None:
|
||||||
|
eng = _make_engagement(client, admin_token)
|
||||||
|
resp = client.post(
|
||||||
|
f"/api/engagements/{eng['id']}/simulations",
|
||||||
|
headers=_h(admin_token),
|
||||||
|
json={"name": "S", "template_id": 9999},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
assert "Template not found" in resp.get_json()["error"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_simulation_without_template_unaffected(
|
||||||
|
client: FlaskClient, admin_token: str
|
||||||
|
) -> None:
|
||||||
|
eng = _make_engagement(client, admin_token)
|
||||||
|
sim = _make_sim(client, admin_token, eng["id"])
|
||||||
|
assert sim["description"] is None
|
||||||
|
assert sim["commands"] is None
|
||||||
|
assert sim["prerequisites"] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_simulation_from_template_status_is_pending(
|
||||||
|
client: FlaskClient, admin_token: str
|
||||||
|
) -> None:
|
||||||
|
eng = _make_engagement(client, admin_token)
|
||||||
|
tmpl = _make_template(client, admin_token)
|
||||||
|
sim = _make_sim(client, admin_token, eng["id"], template_id=tmpl["id"])
|
||||||
|
assert sim["status"] == "pending"
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_template_does_not_cascade_to_simulations(
|
||||||
|
client: FlaskClient, admin_token: str
|
||||||
|
) -> None:
|
||||||
|
eng = _make_engagement(client, admin_token)
|
||||||
|
tmpl = _make_template(client, admin_token)
|
||||||
|
sim = _make_sim(client, admin_token, eng["id"], template_id=tmpl["id"])
|
||||||
|
sid = sim["id"]
|
||||||
|
|
||||||
|
# Delete the template.
|
||||||
|
del_resp = client.delete(
|
||||||
|
f"/api/templates/{tmpl['id']}", headers=_h(admin_token)
|
||||||
|
)
|
||||||
|
assert del_resp.status_code == 204
|
||||||
|
|
||||||
|
# Simulation must still be retrievable.
|
||||||
|
get_resp = client.get(f"/api/simulations/{sid}", headers=_h(admin_token))
|
||||||
|
assert get_resp.status_code == 200
|
||||||
|
assert get_resp.get_json()["id"] == sid
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Migration round-trip
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_migration_0005_round_trip() -> None:
|
||||||
|
engine = create_engine("sqlite:///:memory:")
|
||||||
|
migration_file = (
|
||||||
|
Path(__file__).parent.parent
|
||||||
|
/ "migrations"
|
||||||
|
/ "versions"
|
||||||
|
/ "0005_simulation_templates.py"
|
||||||
|
)
|
||||||
|
|
||||||
|
import importlib.util
|
||||||
|
|
||||||
|
spec = importlib.util.spec_from_file_location("m0005", migration_file)
|
||||||
|
assert spec is not None
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
assert spec.loader is not None
|
||||||
|
spec.loader.exec_module(module) # type: ignore[union-attr]
|
||||||
|
|
||||||
|
with engine.begin() as conn:
|
||||||
|
ctx = MigrationContext.configure(conn)
|
||||||
|
import alembic.op as op_module
|
||||||
|
|
||||||
|
op_module._proxy = Operations(ctx) # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
# Create users table (FK dependency).
|
||||||
|
conn.execute(
|
||||||
|
text(
|
||||||
|
"CREATE TABLE users ("
|
||||||
|
"id INTEGER PRIMARY KEY, "
|
||||||
|
"username TEXT NOT NULL, "
|
||||||
|
"password_hash TEXT NOT NULL, "
|
||||||
|
"role TEXT NOT NULL DEFAULT 'redteam', "
|
||||||
|
"created_at DATETIME"
|
||||||
|
")"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
module.upgrade()
|
||||||
|
|
||||||
|
tables_after = conn.execute(
|
||||||
|
text("SELECT name FROM sqlite_master WHERE type='table'")
|
||||||
|
).fetchall()
|
||||||
|
table_names = {r[0] for r in tables_after}
|
||||||
|
assert "simulation_templates" in table_names
|
||||||
|
|
||||||
|
cols = conn.execute(
|
||||||
|
text("PRAGMA table_info(simulation_templates)")
|
||||||
|
).fetchall()
|
||||||
|
col_names = {c[1] for c in cols}
|
||||||
|
for expected in ("id", "name", "techniques", "tactic_ids", "created_by_id"):
|
||||||
|
assert expected in col_names, f"missing column: {expected}"
|
||||||
|
|
||||||
|
module.downgrade()
|
||||||
|
|
||||||
|
tables_after_down = conn.execute(
|
||||||
|
text("SELECT name FROM sqlite_master WHERE type='table'")
|
||||||
|
).fetchall()
|
||||||
|
table_names_down = {r[0] for r in tables_after_down}
|
||||||
|
assert "simulation_templates" not in table_names_down
|
||||||
Reference in New Issue
Block a user