diff --git a/backend/app/__init__.py b/backend/app/__init__.py index e747eb0..350a0e2 100644 --- a/backend/app/__init__.py +++ b/backend/app/__init__.py @@ -6,7 +6,7 @@ from pathlib import Path from flask import Flask, jsonify, send_from_directory -from backend.app.api import auth_bp, engagements_bp, simulations_bp, users_bp +from backend.app.api import auth_bp, engagements_bp, simulations_bp, templates_bp, users_bp from backend.app.cli import register_cli from backend.app.config import Config, TestConfig from backend.app.errors import register_error_handlers @@ -37,6 +37,7 @@ def create_app(config_object: object | None = None) -> Flask: app.register_blueprint(users_bp) app.register_blueprint(engagements_bp) app.register_blueprint(simulations_bp) + app.register_blueprint(templates_bp) from backend.app.services import mitre as mitre_svc mitre_svc.load_bundle() diff --git a/backend/app/api/__init__.py b/backend/app/api/__init__.py index 30dd1d2..780821a 100644 --- a/backend/app/api/__init__.py +++ b/backend/app/api/__init__.py @@ -2,6 +2,7 @@ from backend.app.api.auth import auth_bp from backend.app.api.engagements import engagements_bp from backend.app.api.simulations import simulations_bp +from backend.app.api.templates import templates_bp from backend.app.api.users import users_bp -__all__ = ["auth_bp", "users_bp", "engagements_bp", "simulations_bp"] +__all__ = ["auth_bp", "users_bp", "engagements_bp", "simulations_bp", "templates_bp"] diff --git a/backend/app/api/simulations.py b/backend/app/api/simulations.py index d5a2f2a..a384577 100644 --- a/backend/app/api/simulations.py +++ b/backend/app/api/simulations.py @@ -46,6 +46,7 @@ def create_simulation(eid: int): if not name: return jsonify({"error": "name is required"}), 400 + template_id = data.get("template_id") sim = Simulation( engagement_id=eid, name=name, @@ -53,6 +54,19 @@ def create_simulation(eid: int): created_at=datetime.now(UTC), created_by_id=g.current_user.id, ) + + if template_id is not None: + from backend.app.models.simulation_template import SimulationTemplate + + tmpl = db.session.get(SimulationTemplate, template_id) + if tmpl is None: + return jsonify({"error": "Template not found"}), 404 + sim.description = tmpl.description + sim.commands = tmpl.commands + sim.prerequisites = tmpl.prerequisites + sim.techniques = list(tmpl.techniques or []) + sim.tactic_ids = list(tmpl.tactic_ids or []) + db.session.add(sim) db.session.commit() return jsonify(serialize_simulation(sim)), 201 diff --git a/backend/app/api/templates.py b/backend/app/api/templates.py new file mode 100644 index 0000000..540a31d --- /dev/null +++ b/backend/app/api/templates.py @@ -0,0 +1,143 @@ +"""SimulationTemplate CRUD endpoints — admin and redteam only.""" +from __future__ import annotations + +from datetime import UTC, datetime + +import sqlalchemy.exc +from flask import Blueprint, g, jsonify, request + +from backend.app.auth import role_required +from backend.app.extensions import db +from backend.app.models.simulation_template import SimulationTemplate +from backend.app.serializers import serialize_template +from backend.app.services import mitre as mitre_svc +from backend.app.services.simulation_workflow import ( + _resolve_tactic_ids, + _resolve_technique_ids, +) + +templates_bp = Blueprint("templates", __name__) + +_MUTABLE_FIELDS = {"name", "description", "commands", "prerequisites", "technique_ids", "tactic_ids"} + + +@templates_bp.get("/api/templates") +@role_required("admin", "redteam") +def list_templates(): + items = SimulationTemplate.query.order_by(SimulationTemplate.name).all() + return jsonify([serialize_template(t) for t in items]), 200 + + +@templates_bp.post("/api/templates") +@role_required("admin", "redteam") +def create_template(): + data = request.get_json(silent=True) or {} + name = (data.get("name") or "").strip() + if not name: + return jsonify({"error": "name is required"}), 400 + + techniques: list[dict] = [] + tactic_ids_val: list[str] = [] + + if "technique_ids" in data: + if not mitre_svc.mitre_loaded: + return jsonify({"error": "mitre bundle not loaded"}), 503 + resolved, err = _resolve_technique_ids(data["technique_ids"]) + if err is not None: + return err + techniques = resolved or [] + + if "tactic_ids" in data: + resolved_ta, err = _resolve_tactic_ids(data["tactic_ids"]) + if err is not None: + return err + tactic_ids_val = resolved_ta or [] + + tmpl = SimulationTemplate( + name=name, + description=data.get("description"), + commands=data.get("commands"), + prerequisites=data.get("prerequisites"), + techniques=techniques, + tactic_ids=tactic_ids_val, + created_at=datetime.now(UTC), + created_by_id=g.current_user.id, + ) + db.session.add(tmpl) + try: + db.session.commit() + except sqlalchemy.exc.IntegrityError: + db.session.rollback() + return jsonify({"error": "template name already exists"}), 409 + + return jsonify(serialize_template(tmpl)), 201 + + +@templates_bp.get("/api/templates/") +@role_required("admin", "redteam") +def get_template(tid: int): + tmpl = db.session.get(SimulationTemplate, tid) + if tmpl is None: + return jsonify({"error": "Template not found"}), 404 + return jsonify(serialize_template(tmpl)), 200 + + +@templates_bp.patch("/api/templates/") +@role_required("admin", "redteam") +def update_template(tid: int): + tmpl = db.session.get(SimulationTemplate, tid) + if tmpl is None: + return jsonify({"error": "Template not found"}), 404 + + data = request.get_json(silent=True) or {} + unknown = set(data.keys()) - _MUTABLE_FIELDS + if unknown: + return jsonify({"error": f"unknown fields: {sorted(unknown)}"}), 400 + + if not data: + return jsonify(serialize_template(tmpl)), 200 + + if "name" in data: + name = (data["name"] or "").strip() + if not name: + return jsonify({"error": "name cannot be empty"}), 400 + tmpl.name = name + + for field in ("description", "commands", "prerequisites"): + if field in data: + setattr(tmpl, field, data[field]) + + if "technique_ids" in data: + if not mitre_svc.mitre_loaded: + return jsonify({"error": "mitre bundle not loaded"}), 503 + resolved, err = _resolve_technique_ids(data["technique_ids"]) + if err is not None: + return err + tmpl.techniques = resolved + + if "tactic_ids" in data: + resolved_ta, err = _resolve_tactic_ids(data["tactic_ids"]) + if err is not None: + return err + tmpl.tactic_ids = resolved_ta + + tmpl.updated_at = datetime.now(UTC) + + try: + db.session.commit() + except sqlalchemy.exc.IntegrityError: + db.session.rollback() + return jsonify({"error": "template name already exists"}), 409 + + return jsonify(serialize_template(tmpl)), 200 + + +@templates_bp.delete("/api/templates/") +@role_required("admin", "redteam") +def delete_template(tid: int): + tmpl = db.session.get(SimulationTemplate, tid) + if tmpl is None: + return jsonify({"error": "Template not found"}), 404 + db.session.delete(tmpl) + db.session.commit() + return "", 204 diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index ab026a8..e432347 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -1,6 +1,15 @@ """SQLAlchemy models.""" from backend.app.models.engagement import Engagement, EngagementStatus from backend.app.models.simulation import Simulation, SimulationStatus +from backend.app.models.simulation_template import SimulationTemplate from backend.app.models.user import User, UserRole -__all__ = ["User", "UserRole", "Engagement", "EngagementStatus", "Simulation", "SimulationStatus"] +__all__ = [ + "User", + "UserRole", + "Engagement", + "EngagementStatus", + "Simulation", + "SimulationStatus", + "SimulationTemplate", +] diff --git a/backend/app/models/simulation_template.py b/backend/app/models/simulation_template.py new file mode 100644 index 0000000..5bf126a --- /dev/null +++ b/backend/app/models/simulation_template.py @@ -0,0 +1,32 @@ +"""SimulationTemplate model.""" +from __future__ import annotations + +from datetime import UTC, datetime + +from backend.app.extensions import db + + +class SimulationTemplate(db.Model): # type: ignore[name-defined] + __tablename__ = "simulation_templates" + + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String(255), nullable=False, unique=True) + description = db.Column(db.Text, nullable=True) + commands = db.Column(db.Text, nullable=True) + prerequisites = db.Column(db.Text, nullable=True) + techniques = db.Column(db.JSON, nullable=False, default=list) + tactic_ids = db.Column(db.JSON, nullable=False, default=list) + created_at = db.Column( + db.DateTime, nullable=False, default=lambda: datetime.now(UTC) + ) + updated_at = db.Column(db.DateTime, nullable=True) + created_by_id = db.Column( + db.Integer, + db.ForeignKey("users.id", ondelete="RESTRICT"), + nullable=False, + ) + + created_by = db.relationship("User", lazy="joined") + + def __repr__(self) -> str: + return f"" diff --git a/backend/app/serializers.py b/backend/app/serializers.py index 41bf4d6..3f3a8dd 100644 --- a/backend/app/serializers.py +++ b/backend/app/serializers.py @@ -5,6 +5,7 @@ from typing import Any from backend.app.models import Engagement, User from backend.app.models.simulation import Simulation +from backend.app.models.simulation_template import SimulationTemplate def serialize_user(user: User) -> dict[str, Any]: @@ -69,6 +70,23 @@ def serialize_simulation(simulation: Simulation) -> dict[str, Any]: } +def serialize_template(t: SimulationTemplate) -> dict[str, Any]: + return { + "id": t.id, + "name": t.name, + "description": t.description, + "commands": t.commands, + "prerequisites": t.prerequisites, + "techniques": _enrich_techniques(t.techniques or []), + "tactics": _enrich_tactics(t.tactic_ids or []), + "created_at": t.created_at.isoformat() if t.created_at else None, + "updated_at": t.updated_at.isoformat() if t.updated_at else None, + "created_by": serialize_user_brief(t.created_by) # type: ignore[arg-type] + if t.created_by + else None, + } + + def serialize_engagement(engagement: Engagement) -> dict[str, Any]: return { "id": engagement.id, diff --git a/backend/migrations/versions/0005_simulation_templates.py b/backend/migrations/versions/0005_simulation_templates.py new file mode 100644 index 0000000..c392357 --- /dev/null +++ b/backend/migrations/versions/0005_simulation_templates.py @@ -0,0 +1,40 @@ +"""create simulation_templates table + +Revision ID: 0005 +Revises: 0004 +Create Date: 2026-05-28 00:00:00.000000 +""" +import sqlalchemy as sa +from alembic import op + +revision = "0005" +down_revision = "0004" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "simulation_templates", + sa.Column("id", sa.Integer(), primary_key=True), + sa.Column("name", sa.String(length=255), nullable=False, unique=True), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("commands", sa.Text(), nullable=True), + sa.Column("prerequisites", sa.Text(), nullable=True), + sa.Column("techniques", sa.JSON(), nullable=False, server_default=sa.text("'[]'")), + sa.Column("tactic_ids", sa.JSON(), nullable=False, server_default=sa.text("'[]'")), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=True), + sa.Column( + "created_by_id", + sa.Integer(), + sa.ForeignKey("users.id", ondelete="RESTRICT"), + nullable=False, + ), + ) + op.create_index("ix_simulation_templates_name", "simulation_templates", ["name"]) + + +def downgrade() -> None: + op.drop_index("ix_simulation_templates_name", "simulation_templates") + op.drop_table("simulation_templates") diff --git a/backend/tests/test_simulation_templates_crud.py b/backend/tests/test_simulation_templates_crud.py new file mode 100644 index 0000000..eb6dbb3 --- /dev/null +++ b/backend/tests/test_simulation_templates_crud.py @@ -0,0 +1,239 @@ +"""SimulationTemplate CRUD: list, create, get, patch, delete + RBAC + dedup.""" +from __future__ import annotations + +from flask.testing import FlaskClient + +from backend.app.extensions import db +from backend.app.models import User +from backend.app.models.simulation_template import SimulationTemplate +from backend.tests.conftest import auth_headers as _h # noqa: E402 + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_template(client: FlaskClient, token: str, **kw) -> dict: + payload = {"name": "Template Alpha", **kw} + resp = client.post("/api/templates", headers=_h(token), json=payload) + assert resp.status_code == 201, resp.get_json() + return resp.get_json() + + +# --------------------------------------------------------------------------- +# List +# --------------------------------------------------------------------------- + + +def test_list_templates_empty(client: FlaskClient, admin_token: str) -> None: + resp = client.get("/api/templates", headers=_h(admin_token)) + assert resp.status_code == 200 + assert resp.get_json() == [] + + +def test_list_templates_soc_forbidden(client: FlaskClient, soc_token: str) -> None: + resp = client.get("/api/templates", headers=_h(soc_token)) + assert resp.status_code == 403 + + +def test_list_templates_unauthenticated(client: FlaskClient) -> None: + resp = client.get("/api/templates") + assert resp.status_code == 401 + + +# --------------------------------------------------------------------------- +# Create +# --------------------------------------------------------------------------- + + +def test_create_template_as_admin( + client: FlaskClient, admin_user: User, admin_token: str +) -> None: + body = _make_template( + client, + admin_token, + description="desc", + commands="cmd", + prerequisites="prereq", + ) + assert body["name"] == "Template Alpha" + assert body["description"] == "desc" + assert body["commands"] == "cmd" + assert body["prerequisites"] == "prereq" + assert body["techniques"] == [] + assert body["tactics"] == [] + assert body["created_by"] == {"id": admin_user.id, "username": "admin1"} + assert body["id"] is not None + + +def test_create_template_as_redteam( + client: FlaskClient, redteam_user: User, redteam_token: str +) -> None: + body = _make_template(client, redteam_token) + assert body["created_by"]["username"] == "redteam1" + + +def test_create_template_soc_forbidden(client: FlaskClient, soc_token: str) -> None: + resp = client.post( + "/api/templates", headers=_h(soc_token), json={"name": "T"} + ) + assert resp.status_code == 403 + + +def test_create_template_missing_name(client: FlaskClient, admin_token: str) -> None: + resp = client.post("/api/templates", headers=_h(admin_token), json={}) + assert resp.status_code == 400 + assert "name" in resp.get_json()["error"] + + +def test_create_template_duplicate_name_409( + client: FlaskClient, admin_token: str +) -> None: + _make_template(client, admin_token) + resp = client.post( + "/api/templates", headers=_h(admin_token), json={"name": "Template Alpha"} + ) + assert resp.status_code == 409 + assert "already exists" in resp.get_json()["error"] + + +# --------------------------------------------------------------------------- +# Get single +# --------------------------------------------------------------------------- + + +def test_get_template(client: FlaskClient, admin_token: str) -> None: + created = _make_template(client, admin_token) + resp = client.get(f"/api/templates/{created['id']}", headers=_h(admin_token)) + assert resp.status_code == 200 + assert resp.get_json()["id"] == created["id"] + + +def test_get_template_not_found(client: FlaskClient, admin_token: str) -> None: + resp = client.get("/api/templates/9999", headers=_h(admin_token)) + assert resp.status_code == 404 + + +def test_get_template_soc_forbidden( + client: FlaskClient, admin_token: str, soc_token: str +) -> None: + created = _make_template(client, admin_token) + resp = client.get(f"/api/templates/{created['id']}", headers=_h(soc_token)) + assert resp.status_code == 403 + + +# --------------------------------------------------------------------------- +# Patch +# --------------------------------------------------------------------------- + + +def test_patch_template_name(client: FlaskClient, admin_token: str) -> None: + created = _make_template(client, admin_token) + resp = client.patch( + f"/api/templates/{created['id']}", + headers=_h(admin_token), + json={"name": "Renamed"}, + ) + assert resp.status_code == 200 + assert resp.get_json()["name"] == "Renamed" + assert resp.get_json()["updated_at"] is not None + + +def test_patch_template_empty_name_rejected( + client: FlaskClient, admin_token: str +) -> None: + created = _make_template(client, admin_token) + resp = client.patch( + f"/api/templates/{created['id']}", + headers=_h(admin_token), + json={"name": ""}, + ) + assert resp.status_code == 400 + + +def test_patch_template_unknown_field_rejected( + client: FlaskClient, admin_token: str +) -> None: + created = _make_template(client, admin_token) + resp = client.patch( + f"/api/templates/{created['id']}", + headers=_h(admin_token), + json={"bogus_field": "x"}, + ) + assert resp.status_code == 400 + assert "unknown fields" in resp.get_json()["error"] + + +def test_patch_template_duplicate_name_409( + client: FlaskClient, admin_token: str +) -> None: + _make_template(client, admin_token, name="T1") + t2 = _make_template(client, admin_token, name="T2") + resp = client.patch( + f"/api/templates/{t2['id']}", + headers=_h(admin_token), + json={"name": "T1"}, + ) + assert resp.status_code == 409 + + +def test_patch_template_soc_forbidden( + client: FlaskClient, admin_token: str, soc_token: str +) -> None: + created = _make_template(client, admin_token) + resp = client.patch( + f"/api/templates/{created['id']}", + headers=_h(soc_token), + json={"name": "X"}, + ) + assert resp.status_code == 403 + + +def test_patch_template_not_found(client: FlaskClient, admin_token: str) -> None: + resp = client.patch( + "/api/templates/9999", headers=_h(admin_token), json={"name": "X"} + ) + assert resp.status_code == 404 + + +# --------------------------------------------------------------------------- +# Delete +# --------------------------------------------------------------------------- + + +def test_delete_template( + client: FlaskClient, app, admin_token: str +) -> None: + created = _make_template(client, admin_token) + resp = client.delete(f"/api/templates/{created['id']}", headers=_h(admin_token)) + assert resp.status_code == 204 + with app.app_context(): + assert db.session.get(SimulationTemplate, created["id"]) is None + + +def test_delete_template_not_found(client: FlaskClient, admin_token: str) -> None: + resp = client.delete("/api/templates/9999", headers=_h(admin_token)) + assert resp.status_code == 404 + + +def test_delete_template_soc_forbidden( + client: FlaskClient, admin_token: str, soc_token: str +) -> None: + created = _make_template(client, admin_token) + resp = client.delete(f"/api/templates/{created['id']}", headers=_h(soc_token)) + assert resp.status_code == 403 + + +# --------------------------------------------------------------------------- +# List returns ordered by name +# --------------------------------------------------------------------------- + + +def test_list_templates_ordered_by_name( + client: FlaskClient, admin_token: str +) -> None: + for name in ("Zebra", "Alpha", "Midpoint"): + _make_template(client, admin_token, name=name) + body = client.get("/api/templates", headers=_h(admin_token)).get_json() + names = [t["name"] for t in body] + assert names == sorted(names) diff --git a/backend/tests/test_simulations_from_template.py b/backend/tests/test_simulations_from_template.py new file mode 100644 index 0000000..1f6a0d6 --- /dev/null +++ b/backend/tests/test_simulations_from_template.py @@ -0,0 +1,195 @@ +"""Tests for creating simulations from a template (POST /api/engagements//simulations).""" +from __future__ import annotations + +from pathlib import Path + +from alembic.operations import Operations +from alembic.runtime.migration import MigrationContext +from flask.testing import FlaskClient +from sqlalchemy import create_engine, text + +from backend.tests.conftest import auth_headers as _h + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_engagement(client: FlaskClient, token: str) -> dict: + resp = client.post( + "/api/engagements", + headers=_h(token), + json={"name": "Op Bravo", "start_date": "2026-06-01"}, + ) + assert resp.status_code == 201, resp.get_json() + return resp.get_json() + + +def _make_template(client: FlaskClient, token: str, **kw) -> dict: + payload = {"name": "Base Template", **kw} + resp = client.post("/api/templates", headers=_h(token), json=payload) + assert resp.status_code == 201, resp.get_json() + return resp.get_json() + + +def _make_sim(client: FlaskClient, token: str, eid: int, **kw) -> dict: + payload = {"name": "Sim From Template", **kw} + resp = client.post( + f"/api/engagements/{eid}/simulations", headers=_h(token), json=payload + ) + assert resp.status_code == 201, resp.get_json() + return resp.get_json() + + +# --------------------------------------------------------------------------- +# Instantiation +# --------------------------------------------------------------------------- + + +def test_create_simulation_from_template_copies_fields( + client: FlaskClient, admin_token: str +) -> None: + eng = _make_engagement(client, admin_token) + tmpl = _make_template( + client, + admin_token, + description="template desc", + commands="template cmd", + prerequisites="template prereq", + ) + sim = _make_sim(client, admin_token, eng["id"], template_id=tmpl["id"]) + + assert sim["description"] == "template desc" + assert sim["commands"] == "template cmd" + assert sim["prerequisites"] == "template prereq" + assert sim["techniques"] == [] + assert sim["tactics"] == [] + assert sim["status"] == "pending" + + +def test_create_simulation_name_overrides_template( + client: FlaskClient, admin_token: str +) -> None: + eng = _make_engagement(client, admin_token) + tmpl = _make_template(client, admin_token) + sim = _make_sim( + client, admin_token, eng["id"], name="Custom Name", template_id=tmpl["id"] + ) + assert sim["name"] == "Custom Name" + + +def test_create_simulation_template_not_found( + client: FlaskClient, admin_token: str +) -> None: + eng = _make_engagement(client, admin_token) + resp = client.post( + f"/api/engagements/{eng['id']}/simulations", + headers=_h(admin_token), + json={"name": "S", "template_id": 9999}, + ) + assert resp.status_code == 404 + assert "Template not found" in resp.get_json()["error"] + + +def test_create_simulation_without_template_unaffected( + client: FlaskClient, admin_token: str +) -> None: + eng = _make_engagement(client, admin_token) + sim = _make_sim(client, admin_token, eng["id"]) + assert sim["description"] is None + assert sim["commands"] is None + assert sim["prerequisites"] is None + + +def test_create_simulation_from_template_status_is_pending( + client: FlaskClient, admin_token: str +) -> None: + eng = _make_engagement(client, admin_token) + tmpl = _make_template(client, admin_token) + sim = _make_sim(client, admin_token, eng["id"], template_id=tmpl["id"]) + assert sim["status"] == "pending" + + +def test_delete_template_does_not_cascade_to_simulations( + client: FlaskClient, admin_token: str +) -> None: + eng = _make_engagement(client, admin_token) + tmpl = _make_template(client, admin_token) + sim = _make_sim(client, admin_token, eng["id"], template_id=tmpl["id"]) + sid = sim["id"] + + # Delete the template. + del_resp = client.delete( + f"/api/templates/{tmpl['id']}", headers=_h(admin_token) + ) + assert del_resp.status_code == 204 + + # Simulation must still be retrievable. + get_resp = client.get(f"/api/simulations/{sid}", headers=_h(admin_token)) + assert get_resp.status_code == 200 + assert get_resp.get_json()["id"] == sid + + +# --------------------------------------------------------------------------- +# Migration round-trip +# --------------------------------------------------------------------------- + + +def test_migration_0005_round_trip() -> None: + engine = create_engine("sqlite:///:memory:") + migration_file = ( + Path(__file__).parent.parent + / "migrations" + / "versions" + / "0005_simulation_templates.py" + ) + + import importlib.util + + spec = importlib.util.spec_from_file_location("m0005", migration_file) + assert spec is not None + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) # type: ignore[union-attr] + + with engine.begin() as conn: + ctx = MigrationContext.configure(conn) + import alembic.op as op_module + + op_module._proxy = Operations(ctx) # type: ignore[attr-defined] + + # Create users table (FK dependency). + conn.execute( + text( + "CREATE TABLE users (" + "id INTEGER PRIMARY KEY, " + "username TEXT NOT NULL, " + "password_hash TEXT NOT NULL, " + "role TEXT NOT NULL DEFAULT 'redteam', " + "created_at DATETIME" + ")" + ) + ) + + module.upgrade() + + tables_after = conn.execute( + text("SELECT name FROM sqlite_master WHERE type='table'") + ).fetchall() + table_names = {r[0] for r in tables_after} + assert "simulation_templates" in table_names + + cols = conn.execute( + text("PRAGMA table_info(simulation_templates)") + ).fetchall() + col_names = {c[1] for c in cols} + for expected in ("id", "name", "techniques", "tactic_ids", "created_by_id"): + assert expected in col_names, f"missing column: {expected}" + + module.downgrade() + + tables_after_down = conn.execute( + text("SELECT name FROM sqlite_master WHERE type='table'") + ).fetchall() + table_names_down = {r[0] for r in tables_after_down} + assert "simulation_templates" not in table_names_down