From 55f993fa245ca20ef185863f11b3ab0e7f4e1adc Mon Sep 17 00:00:00 2001 From: Knacky Date: Thu, 28 May 2026 07:04:25 +0200 Subject: [PATCH] =?UTF-8?q?fix(backend):=20sprint=205=20post-review=20?= =?UTF-8?q?=E2=80=94=20name=20fallback,=20isinstance=20guards,=20400=20tes?= =?UTF-8?q?ts?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - create_simulation: name falls back to template.name when template_id provided and name is absent/empty (AC-27.1) - templates POST/PATCH: isinstance(list) check on technique_ids/tactic_ids before resolving, returns 400 with clear message - 5 new tests: unknown technique_id → 400 (POST+PATCH), unknown tactic_id → 400 (POST+PATCH), name fallback to template.name - mypy: merged template branch into if/else to eliminate union-attr false positives Co-Authored-By: Claude Sonnet 4.6 --- backend/app/api/simulations.py | 29 +++++++---- backend/app/api/templates.py | 8 +++ .../tests/test_simulation_templates_crud.py | 50 +++++++++++++++++++ .../tests/test_simulations_from_template.py | 14 ++++++ 4 files changed, 91 insertions(+), 10 deletions(-) diff --git a/backend/app/api/simulations.py b/backend/app/api/simulations.py index a384577..645dece 100644 --- a/backend/app/api/simulations.py +++ b/backend/app/api/simulations.py @@ -43,17 +43,7 @@ def create_simulation(eid: int): data = request.get_json(silent=True) or {} name = (data.get("name") or "").strip() - if not name: - return jsonify({"error": "name is required"}), 400 - template_id = data.get("template_id") - sim = Simulation( - engagement_id=eid, - name=name, - status=SimulationStatus.PENDING, - created_at=datetime.now(UTC), - created_by_id=g.current_user.id, - ) if template_id is not None: from backend.app.models.simulation_template import SimulationTemplate @@ -61,11 +51,30 @@ def create_simulation(eid: int): tmpl = db.session.get(SimulationTemplate, template_id) if tmpl is None: return jsonify({"error": "Template not found"}), 404 + if not name: + name = tmpl.name + sim = Simulation( + engagement_id=eid, + name=name, + status=SimulationStatus.PENDING, + created_at=datetime.now(UTC), + created_by_id=g.current_user.id, + ) sim.description = tmpl.description sim.commands = tmpl.commands sim.prerequisites = tmpl.prerequisites sim.techniques = list(tmpl.techniques or []) sim.tactic_ids = list(tmpl.tactic_ids or []) + else: + if not name: + return jsonify({"error": "name is required"}), 400 + sim = Simulation( + engagement_id=eid, + name=name, + status=SimulationStatus.PENDING, + created_at=datetime.now(UTC), + created_by_id=g.current_user.id, + ) db.session.add(sim) db.session.commit() diff --git a/backend/app/api/templates.py b/backend/app/api/templates.py index 540a31d..4240011 100644 --- a/backend/app/api/templates.py +++ b/backend/app/api/templates.py @@ -40,6 +40,8 @@ def create_template(): tactic_ids_val: list[str] = [] if "technique_ids" in data: + if not isinstance(data["technique_ids"], list): + return jsonify({"error": "technique_ids must be a list"}), 400 if not mitre_svc.mitre_loaded: return jsonify({"error": "mitre bundle not loaded"}), 503 resolved, err = _resolve_technique_ids(data["technique_ids"]) @@ -48,6 +50,8 @@ def create_template(): techniques = resolved or [] if "tactic_ids" in data: + if not isinstance(data["tactic_ids"], list): + return jsonify({"error": "tactic_ids must be a list"}), 400 resolved_ta, err = _resolve_tactic_ids(data["tactic_ids"]) if err is not None: return err @@ -108,6 +112,8 @@ def update_template(tid: int): setattr(tmpl, field, data[field]) if "technique_ids" in data: + if not isinstance(data["technique_ids"], list): + return jsonify({"error": "technique_ids must be a list"}), 400 if not mitre_svc.mitre_loaded: return jsonify({"error": "mitre bundle not loaded"}), 503 resolved, err = _resolve_technique_ids(data["technique_ids"]) @@ -116,6 +122,8 @@ def update_template(tid: int): tmpl.techniques = resolved if "tactic_ids" in data: + if not isinstance(data["tactic_ids"], list): + return jsonify({"error": "tactic_ids must be a list"}), 400 resolved_ta, err = _resolve_tactic_ids(data["tactic_ids"]) if err is not None: return err diff --git a/backend/tests/test_simulation_templates_crud.py b/backend/tests/test_simulation_templates_crud.py index eb6dbb3..51c30ea 100644 --- a/backend/tests/test_simulation_templates_crud.py +++ b/backend/tests/test_simulation_templates_crud.py @@ -97,6 +97,30 @@ def test_create_template_duplicate_name_409( assert "already exists" in resp.get_json()["error"] +def test_create_template_unknown_technique_id_400( + client: FlaskClient, admin_token: str +) -> None: + resp = client.post( + "/api/templates", + headers=_h(admin_token), + json={"name": "T", "technique_ids": ["T9999.999"]}, + ) + assert resp.status_code == 400 + assert "unknown technique id" in resp.get_json()["error"] + + +def test_create_template_unknown_tactic_id_400( + client: FlaskClient, admin_token: str +) -> None: + resp = client.post( + "/api/templates", + headers=_h(admin_token), + json={"name": "T", "tactic_ids": ["TA9999"]}, + ) + assert resp.status_code == 400 + assert "unknown tactic id" in resp.get_json()["error"] + + # --------------------------------------------------------------------------- # Get single # --------------------------------------------------------------------------- @@ -196,6 +220,32 @@ def test_patch_template_not_found(client: FlaskClient, admin_token: str) -> None assert resp.status_code == 404 +def test_patch_template_unknown_technique_id_400( + client: FlaskClient, admin_token: str +) -> None: + created = _make_template(client, admin_token) + resp = client.patch( + f"/api/templates/{created['id']}", + headers=_h(admin_token), + json={"technique_ids": ["T9999.999"]}, + ) + assert resp.status_code == 400 + assert "unknown technique id" in resp.get_json()["error"] + + +def test_patch_template_unknown_tactic_id_400( + client: FlaskClient, admin_token: str +) -> None: + created = _make_template(client, admin_token) + resp = client.patch( + f"/api/templates/{created['id']}", + headers=_h(admin_token), + json={"tactic_ids": ["TA9999"]}, + ) + assert resp.status_code == 400 + assert "unknown tactic id" in resp.get_json()["error"] + + # --------------------------------------------------------------------------- # Delete # --------------------------------------------------------------------------- diff --git a/backend/tests/test_simulations_from_template.py b/backend/tests/test_simulations_from_template.py index 1f6a0d6..f0e0a3d 100644 --- a/backend/tests/test_simulations_from_template.py +++ b/backend/tests/test_simulations_from_template.py @@ -78,6 +78,20 @@ def test_create_simulation_name_overrides_template( assert sim["name"] == "Custom Name" +def test_create_simulation_name_falls_back_to_template_name( + client: FlaskClient, admin_token: str +) -> None: + eng = _make_engagement(client, admin_token) + tmpl = _make_template(client, admin_token, name="Recon Template") + resp = client.post( + f"/api/engagements/{eng['id']}/simulations", + headers=_h(admin_token), + json={"template_id": tmpl["id"]}, + ) + assert resp.status_code == 201 + assert resp.get_json()["name"] == "Recon Template" + + def test_create_simulation_template_not_found( client: FlaskClient, admin_token: str ) -> None: