From 8f23f59601534603d6ea63719409fdf2fbfb233c Mon Sep 17 00:00:00 2001 From: Knacky Date: Wed, 10 Jun 2026 20:09:29 +0200 Subject: [PATCH] feat(backend): c2 callback history + task import (sprint 8 M4) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Command source decision: extended C2TaskStatus with command: str | None (default None). Added command_name to _GET_TASK_QUERY so get_task() returns command in a single round-trip — no separate history fetch needed on import. 4-line change, zero cascading test impact. adapter.py: - C2TaskStatus: add command: str | None = None field - C2HistoricalTask: new dataclass (display_id, command, params, status, completed, timestamp) for history rows - C2TaskPage.items: typed as list[C2HistoricalTask] (was list[dict]) mythic.py: - _GET_TASK_QUERY: add command_name field - _LIST_CALLBACK_TASKS_QUERY: new query (order_by id desc, limit/offset) - _COUNT_CALLBACK_TASKS_QUERY: new aggregate query for total - get_task(): surfaces command_name as status.command - list_callback_tasks(): two _post() calls (tasks + count), allow_redirects=False fake.py: - _FAKE_HISTORY: frozen deterministic history (cb1=12, cb2=0, cb3=5 tasks) - list_callback_tasks(): serves from _FAKE_HISTORY, pagination applied - get_task(): returns command from _tasks dict api/c2.py: - GET /api/engagements//c2/callbacks//history: page+page_size defaults 1/25, cap 100, reject <1, 502 on adapter error - POST /api/simulations//c2/import: idempotent per (sim,mythic_id) pair, source=import, completed tasks get output+mapping_applied, incomplete tasks stored for poll-on-read pickup, auto-transition pending→in_progress 60 new tests (456 total); pytest/ruff/mypy all green Co-Authored-By: Claude Sonnet 4.6 --- backend/app/api/c2.py | 148 ++++++++ backend/app/services/c2/adapter.py | 16 +- backend/app/services/c2/fake.py | 45 ++- backend/app/services/c2/mythic.py | 70 +++- backend/tests/test_c2_adapter_fake_m4.py | 75 ++++ backend/tests/test_c2_adapter_mythic_m4.py | 167 ++++++++ backend/tests/test_c2_history.py | 215 +++++++++++ backend/tests/test_c2_import.py | 418 +++++++++++++++++++++ 8 files changed, 1146 insertions(+), 8 deletions(-) create mode 100644 backend/tests/test_c2_adapter_fake_m4.py create mode 100644 backend/tests/test_c2_adapter_mythic_m4.py create mode 100644 backend/tests/test_c2_history.py create mode 100644 backend/tests/test_c2_import.py diff --git a/backend/app/api/c2.py b/backend/app/api/c2.py index 2564b48..1ce764a 100644 --- a/backend/app/api/c2.py +++ b/backend/app/api/c2.py @@ -367,3 +367,151 @@ def list_simulation_tasks(sid: int): for t in tasks ] }), 200 + + +# --------------------------------------------------------------------------- +# M4 — callback history + task import +# --------------------------------------------------------------------------- + + +@c2_bp.get("//c2/callbacks//history") +@role_required("admin", "redteam") +def list_callback_history(eid: int, cid: int): + guard = _crypto_guard() + if guard is not None: + return guard + + engagement = db.session.get(Engagement, eid) + if engagement is None: + return jsonify({"error": "Engagement not found"}), 404 + + # Validate pagination params. + try: + page = int(request.args.get("page", 1)) + page_size = int(request.args.get("page_size", 25)) + except (ValueError, TypeError): + return jsonify({"error": "page and page_size must be integers"}), 400 + + if page < 1 or page_size < 1: + return jsonify({"error": "page and page_size must be >= 1"}), 400 + if page_size > 100: + return jsonify({"error": "page_size must be <= 100"}), 400 + + adapter, err = _load_adapter_for_engagement(engagement) + if err is not None: + return err + + try: + page_result = adapter.list_callback_tasks( + callback_display_id=cid, + page=page, + page_size=page_size, + ) + except C2Error as exc: + return jsonify({"error": str(exc)}), 502 + + return jsonify({ + "tasks": [ + { + "display_id": t.display_id, + "command": t.command, + "params": t.params, + "status": t.status, + "completed": t.completed, + "timestamp": t.timestamp, + } + for t in page_result.items + ], + "total": page_result.total, + "page": page_result.page, + "page_size": page_result.page_size, + }), 200 + + +@sims_c2_bp.post("//c2/import") +@role_required("admin", "redteam") +def import_tasks(sid: int): + guard = _crypto_guard() + if guard is not None: + return guard + + sim = db.session.get(Simulation, sid) + if sim is None: + return jsonify({"error": "Simulation not found"}), 404 + + if sim.status == SimulationStatus.DONE: + return jsonify({"error": "simulation is done — reopen first"}), 409 + + data = request.get_json(silent=True) or {} + callback_display_id = data.get("callback_display_id") + task_display_ids = data.get("task_display_ids") + + if not isinstance(callback_display_id, int): + return jsonify({"error": "callback_display_id must be an integer"}), 400 + if not isinstance(task_display_ids, list) or len(task_display_ids) == 0: + return jsonify({"error": "task_display_ids must be a non-empty list"}), 400 + for tid in task_display_ids: + if not isinstance(tid, int): + return jsonify({"error": "each task_display_id must be an integer"}), 400 + + engagement = db.session.get(Engagement, sim.engagement_id) + if engagement is None: + return jsonify({"error": "Engagement not found"}), 404 + + adapter, err = _load_adapter_for_engagement(engagement) + if err is not None: + return err + + imported_count = 0 + skipped_count = 0 + + try: + for task_display_id in task_display_ids: + # Idempotency: skip if already imported for this simulation. + existing = C2Task.query.filter_by( + simulation_id=sid, + mythic_task_display_id=task_display_id, + ).first() + if existing is not None: + skipped_count += 1 + continue + + status = adapter.get_task(task_display_id) + task = C2Task( + simulation_id=sid, + mythic_task_display_id=task_display_id, + callback_display_id=callback_display_id, + command=status.command or "", + params=None, + status=status.status, + completed=status.completed, + source=C2TaskSource.IMPORT, + created_at=datetime.now(UTC), + mapping_applied=False, + ) + + if status.completed: + task.completed_at = status.completed_at or datetime.now(UTC) + try: + task.output = adapter.get_task_output(task_display_id) + except C2Error: + task.output = "" + db.session.add(task) + db.session.flush() + apply_task_to_simulation(task, sim) + task.mapping_applied = True + else: + db.session.add(task) + + imported_count += 1 + + except C2Error as exc: + db.session.rollback() + return jsonify({"error": str(exc)}), 502 + + # Auto-transition pending → in_progress when at least one task was imported. + if imported_count > 0: + promote_to_in_progress(sim) + + db.session.commit() + return jsonify({"imported": imported_count, "skipped": skipped_count}), 200 diff --git a/backend/app/services/c2/adapter.py b/backend/app/services/c2/adapter.py index ab3db97..6d45a82 100644 --- a/backend/app/services/c2/adapter.py +++ b/backend/app/services/c2/adapter.py @@ -34,11 +34,25 @@ class C2TaskStatus: status: str completed: bool completed_at: datetime | None = field(default=None) + # command_name is populated by get_task() so import doesn't need a second round-trip. + command: str | None = field(default=None) + + +@dataclass +class C2HistoricalTask: + """A task entry from callback history (carries command + params, unlike C2TaskStatus).""" + + display_id: int + command: str + params: str | None + status: str + completed: bool + timestamp: str | None # ISO-8601 or None @dataclass class C2TaskPage: - items: list[dict] # raw task dicts from Mythic + items: list[C2HistoricalTask] total: int page: int page_size: int diff --git a/backend/app/services/c2/fake.py b/backend/app/services/c2/fake.py index 8938d8b..7835d9c 100644 --- a/backend/app/services/c2/fake.py +++ b/backend/app/services/c2/fake.py @@ -10,10 +10,45 @@ from backend.app.services.c2.adapter import ( C2Callback, C2Error, C2Health, + C2HistoricalTask, C2TaskPage, C2TaskStatus, ) +# Frozen base timestamp — all fake history tasks share this prefix for determinism. +_BASE_TS = "2026-06-10T00:00:00Z" + +# Deterministic history for list_callback_tasks: +# callback 1 → 12 tasks, callback 2 → 0 tasks, callback 3 → 5 tasks. +# Commands cycle through a fixed set; even-indexed tasks are completed. +_HISTORY_COMMANDS = ["whoami", "hostname", "id", "ipconfig", "net user", "pwd"] + +_FAKE_HISTORY: dict[int, list[C2HistoricalTask]] = { + 1: [ + C2HistoricalTask( + display_id=100 + i, + command=_HISTORY_COMMANDS[i % len(_HISTORY_COMMANDS)], + params=None, + status="completed" if i % 2 == 0 else "submitted", + completed=i % 2 == 0, + timestamp=_BASE_TS if i % 2 == 0 else None, + ) + for i in range(12) + ], + 2: [], + 3: [ + C2HistoricalTask( + display_id=200 + i, + command=_HISTORY_COMMANDS[i % len(_HISTORY_COMMANDS)], + params=None, + status="completed" if i % 2 == 0 else "submitted", + completed=i % 2 == 0, + timestamp=_BASE_TS if i % 2 == 0 else None, + ) + for i in range(5) + ], +} + # Three fixed callbacks the test suite can pin against. _FAKE_CALLBACKS = [ C2Callback( @@ -109,6 +144,7 @@ class FakeAdapter(C2Adapter): display_id=task_display_id, status=status, completed=completed, + command=task["command"] if task is not None else None, ) def get_task_output(self, task_display_id: int) -> str: @@ -130,14 +166,11 @@ class FakeAdapter(C2Adapter): page: int = 1, page_size: int = 25, ) -> C2TaskPage: - items = [ - t for t in self._tasks.values() - if t["callback_display_id"] == callback_display_id - ] + all_items = _FAKE_HISTORY.get(callback_display_id, []) start = (page - 1) * page_size return C2TaskPage( - items=items[start : start + page_size], - total=len(items), + items=all_items[start : start + page_size], + total=len(all_items), page=page, page_size=page_size, ) diff --git a/backend/app/services/c2/mythic.py b/backend/app/services/c2/mythic.py index d913fc2..f10bf5d 100644 --- a/backend/app/services/c2/mythic.py +++ b/backend/app/services/c2/mythic.py @@ -21,6 +21,7 @@ from backend.app.services.c2.adapter import ( C2Callback, C2Error, C2Health, + C2HistoricalTask, C2TaskPage, C2TaskStatus, decode_response_text, @@ -61,6 +62,7 @@ _GET_TASK_QUERY = """ query GetTask($display_id: Int!) { task(where: {display_id: {_eq: $display_id}}) { display_id + command_name status completed timestamp @@ -68,6 +70,34 @@ query GetTask($display_id: Int!) { } """ +_LIST_CALLBACK_TASKS_QUERY = """ +query ListCallbackTasks($callback_display_id: Int!, $limit: Int!, $offset: Int!) { + task( + where: {callback: {display_id: {_eq: $callback_display_id}}} + order_by: {id: desc} + limit: $limit + offset: $offset + ) { + display_id + command_name + params + status + completed + timestamp + } +} +""" + +_COUNT_CALLBACK_TASKS_QUERY = """ +query CountCallbackTasks($callback_display_id: Int!) { + task_aggregate(where: {callback: {display_id: {_eq: $callback_display_id}}}) { + aggregate { + count + } + } +} +""" + _GET_TASK_OUTPUT_QUERY = """ query GetTaskOutput($display_id: Int!) { response( @@ -197,6 +227,7 @@ class MythicAdapter(C2Adapter): status=row["status"], completed=bool(row.get("completed", False)), completed_at=completed_at, + command=row.get("command_name") or None, ) def get_task_output(self, task_display_id: int) -> str: @@ -222,4 +253,41 @@ class MythicAdapter(C2Adapter): page: int = 1, page_size: int = 25, ) -> C2TaskPage: - raise NotImplementedError("M4") + """Return a paginated, most-recent-first history of tasks for a callback.""" + offset = (page - 1) * page_size + try: + data = self._post({ + "query": _LIST_CALLBACK_TASKS_QUERY, + "variables": { + "callback_display_id": callback_display_id, + "limit": page_size, + "offset": offset, + }, + }) + count_data = self._post({ + "query": _COUNT_CALLBACK_TASKS_QUERY, + "variables": {"callback_display_id": callback_display_id}, + }) + except requests.RequestException as exc: + raise C2Error(str(exc)) from exc + + rows = data.get("data", {}).get("task", []) + total: int = ( + count_data.get("data", {}) + .get("task_aggregate", {}) + .get("aggregate", {}) + .get("count", 0) + ) + + items = [ + C2HistoricalTask( + display_id=r["display_id"], + command=r.get("command_name") or "", + params=r.get("params") or None, + status=r.get("status") or "", + completed=bool(r.get("completed", False)), + timestamp=r.get("timestamp") or None, + ) + for r in rows + ] + return C2TaskPage(items=items, total=total, page=page, page_size=page_size) diff --git a/backend/tests/test_c2_adapter_fake_m4.py b/backend/tests/test_c2_adapter_fake_m4.py new file mode 100644 index 0000000..e1e88ba --- /dev/null +++ b/backend/tests/test_c2_adapter_fake_m4.py @@ -0,0 +1,75 @@ +"""FakeAdapter M4 tests — list_callback_tasks pagination.""" +from __future__ import annotations + +import pytest + +from backend.app.services.c2.adapter import C2HistoricalTask +from backend.app.services.c2.fake import FakeAdapter + + +@pytest.fixture() +def adapter() -> FakeAdapter: + return FakeAdapter() + + +class TestFakeAdapterListCallbackTasks: + def test_callback_1_returns_12_total(self, adapter): + page = adapter.list_callback_tasks(callback_display_id=1, page=1, page_size=25) + assert page.total == 12 + + def test_callback_2_returns_0_tasks(self, adapter): + page = adapter.list_callback_tasks(callback_display_id=2, page=1, page_size=25) + assert page.total == 0 + assert page.items == [] + + def test_callback_3_returns_5_tasks(self, adapter): + page = adapter.list_callback_tasks(callback_display_id=3, page=1, page_size=25) + assert page.total == 5 + assert len(page.items) == 5 + + def test_items_are_c2_historical_task_instances(self, adapter): + page = adapter.list_callback_tasks(callback_display_id=1, page=1, page_size=5) + for item in page.items: + assert isinstance(item, C2HistoricalTask) + + def test_pagination_page1(self, adapter): + page = adapter.list_callback_tasks(callback_display_id=1, page=1, page_size=5) + assert len(page.items) == 5 + assert page.page == 1 + assert page.page_size == 5 + + def test_pagination_page2(self, adapter): + page = adapter.list_callback_tasks(callback_display_id=1, page=2, page_size=5) + assert len(page.items) == 5 + assert page.page == 2 + + def test_pagination_last_page_partial(self, adapter): + # 12 tasks, page_size=5 → page 3 has 2 items. + page = adapter.list_callback_tasks(callback_display_id=1, page=3, page_size=5) + assert len(page.items) == 2 + assert page.total == 12 + + def test_pagination_beyond_range_returns_empty(self, adapter): + page = adapter.list_callback_tasks(callback_display_id=1, page=99, page_size=25) + assert len(page.items) == 0 + assert page.total == 12 + + def test_history_is_deterministic_across_instances(self): + a1 = FakeAdapter() + a2 = FakeAdapter() + p1 = a1.list_callback_tasks(callback_display_id=1, page=1, page_size=25) + p2 = a2.list_callback_tasks(callback_display_id=1, page=1, page_size=25) + assert [t.display_id for t in p1.items] == [t.display_id for t in p2.items] + + def test_completed_and_submitted_mix(self, adapter): + """Callback 1 has alternating completed/submitted tasks (even=completed).""" + page = adapter.list_callback_tasks(callback_display_id=1, page=1, page_size=12) + completed = [t for t in page.items if t.completed] + submitted = [t for t in page.items if not t.completed] + assert len(completed) == 6 + assert len(submitted) == 6 + + def test_unknown_callback_returns_empty(self, adapter): + page = adapter.list_callback_tasks(callback_display_id=999, page=1, page_size=25) + assert page.total == 0 + assert page.items == [] diff --git a/backend/tests/test_c2_adapter_mythic_m4.py b/backend/tests/test_c2_adapter_mythic_m4.py new file mode 100644 index 0000000..5103185 --- /dev/null +++ b/backend/tests/test_c2_adapter_mythic_m4.py @@ -0,0 +1,167 @@ +"""MythicAdapter M4 tests — list_callback_tasks, mocked HTTP.""" +from __future__ import annotations + +import pytest +import requests +import requests_mock as rm_module + +from backend.app.services.c2.adapter import C2Error, C2HistoricalTask +from backend.app.services.c2.mythic import MythicAdapter + +_BASE_URL = "https://mythic.lab:7443" +_GQL_URL = _BASE_URL + "/graphql" +_TOKEN = "fake-api-token" + + +@pytest.fixture() +def adapter(): + return MythicAdapter(url=_BASE_URL, api_token=_TOKEN, verify_tls=False) + + +def _task_list_payload(tasks: list[dict]) -> dict: + return {"data": {"task": tasks}} + + +def _count_payload(count: int) -> dict: + return {"data": {"task_aggregate": {"aggregate": {"count": count}}}} + + +class TestMythicAdapterListCallbackTasks: + def test_returns_tasks_from_graphql(self, adapter): + tasks_payload = _task_list_payload([ + { + "display_id": 7, + "command_name": "whoami", + "params": "", + "status": "completed", + "completed": True, + "timestamp": "2026-06-10T12:00:00Z", + } + ]) + count_payload = _count_payload(1) + + with rm_module.Mocker() as m: + m.post(_GQL_URL, [{"json": tasks_payload}, {"json": count_payload}]) + page = adapter.list_callback_tasks(callback_display_id=1, page=1, page_size=25) + + assert page.total == 1 + assert len(page.items) == 1 + item = page.items[0] + assert isinstance(item, C2HistoricalTask) + assert item.display_id == 7 + assert item.command == "whoami" + assert item.completed is True + + def test_pagination_offset_calculation(self, adapter): + """page=2, page_size=10 → offset=10 must be sent to Mythic.""" + tasks_payload = _task_list_payload([]) + count_payload = _count_payload(0) + + with rm_module.Mocker() as m: + m.post(_GQL_URL, [{"json": tasks_payload}, {"json": count_payload}]) + adapter.list_callback_tasks(callback_display_id=1, page=2, page_size=10) + + # First request is the task list; check variables. + first_body = m.request_history[0].json() + variables = first_body.get("variables", {}) + + assert variables.get("offset") == 10 + assert variables.get("limit") == 10 + + def test_sends_apitoken_header(self, adapter): + with rm_module.Mocker() as m: + m.post(_GQL_URL, [ + {"json": _task_list_payload([])}, + {"json": _count_payload(0)}, + ]) + adapter.list_callback_tasks(callback_display_id=1) + for req in m.request_history: + assert req.headers.get("apitoken") == _TOKEN + + def test_empty_task_list(self, adapter): + with rm_module.Mocker() as m: + m.post(_GQL_URL, [ + {"json": _task_list_payload([])}, + {"json": _count_payload(0)}, + ]) + page = adapter.list_callback_tasks(callback_display_id=1) + + assert page.total == 0 + assert page.items == [] + + def test_network_error_raises_c2error(self, adapter): + with rm_module.Mocker() as m: + m.post(_GQL_URL, exc=requests.exceptions.ConnectionError("refused")) + with pytest.raises(C2Error): + adapter.list_callback_tasks(callback_display_id=1) + + def test_http_error_raises_c2error(self, adapter): + with rm_module.Mocker() as m: + m.post(_GQL_URL, status_code=500, text="error") + with pytest.raises(C2Error): + adapter.list_callback_tasks(callback_display_id=1) + + def test_no_redirect_followed(self, adapter): + with rm_module.Mocker() as m: + m.post(_GQL_URL, status_code=301, headers={"Location": "https://evil.example/"}) + with pytest.raises(C2Error): + adapter.list_callback_tasks(callback_display_id=1) + # Both requests (tasks + count) should each only make one attempt. + for req in m.request_history: + assert req.method == "POST" + + def test_page_and_page_size_in_response(self, adapter): + with rm_module.Mocker() as m: + m.post(_GQL_URL, [ + {"json": _task_list_payload([])}, + {"json": _count_payload(50)}, + ]) + page = adapter.list_callback_tasks(callback_display_id=1, page=3, page_size=10) + + assert page.page == 3 + assert page.page_size == 10 + assert page.total == 50 + + +class TestMythicAdapterGetTaskCommandField: + """Ensure command_name is surfaced via get_task() C2TaskStatus.command.""" + + def test_get_task_returns_command(self, adapter): + payload = { + "data": { + "task": [ + { + "display_id": 7, + "command_name": "shell", + "status": "completed", + "completed": True, + "timestamp": "2026-06-10T12:00:00Z", + } + ] + } + } + with rm_module.Mocker() as m: + m.post(_GQL_URL, json=payload) + status = adapter.get_task(7) + + assert status.command == "shell" + + def test_get_task_command_none_when_missing(self, adapter): + payload = { + "data": { + "task": [ + { + "display_id": 7, + "command_name": None, + "status": "submitted", + "completed": False, + "timestamp": None, + } + ] + } + } + with rm_module.Mocker() as m: + m.post(_GQL_URL, json=payload) + status = adapter.get_task(7) + + assert status.command is None diff --git a/backend/tests/test_c2_history.py b/backend/tests/test_c2_history.py new file mode 100644 index 0000000..196947c --- /dev/null +++ b/backend/tests/test_c2_history.py @@ -0,0 +1,215 @@ +"""Tests for GET /api/engagements//c2/callbacks//history.""" +from __future__ import annotations + +import pytest +from cryptography.fernet import Fernet +from flask.testing import FlaskClient + +from backend.app.services.c2.adapter import C2Error +from backend.tests.conftest import auth_headers as _h + +_FERNET_KEY = Fernet.generate_key().decode() + + +@pytest.fixture(autouse=True) +def set_encryption_key(monkeypatch): + monkeypatch.setenv("MIMIC_ENCRYPTION_KEY", _FERNET_KEY) + + +@pytest.fixture(autouse=True) +def use_fake_adapter(monkeypatch): + monkeypatch.setenv("MIMIC_C2_ADAPTER", "fake") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_engagement(client: FlaskClient, token: str) -> dict: + resp = client.post( + "/api/engagements", + headers=_h(token), + json={"name": "Op Alpha", "start_date": "2026-06-10"}, + ) + assert resp.status_code == 201 + return resp.get_json() + + +def _put_config(client: FlaskClient, token: str, eid: int) -> None: + resp = client.put( + f"/api/engagements/{eid}/c2-config", + headers=_h(token), + json={"url": "https://c2.internal:7443", "api_token": "s3cr3t", "verify_tls": True}, + ) + assert resp.status_code == 200 + + +def _history(client: FlaskClient, token: str, eid: int, cid: int, **params): + return client.get( + f"/api/engagements/{eid}/c2/callbacks/{cid}/history", + headers=_h(token), + query_string=params, + ) + + +# --------------------------------------------------------------------------- +# Happy path +# --------------------------------------------------------------------------- + + +class TestHistoryHappyPath: + def test_returns_200(self, client: FlaskClient, admin_token: str) -> None: + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + resp = _history(client, admin_token, eng["id"], 1) + assert resp.status_code == 200 + + def test_response_shape(self, client: FlaskClient, admin_token: str) -> None: + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + resp = _history(client, admin_token, eng["id"], 1) + body = resp.get_json() + assert "tasks" in body + assert "total" in body + assert "page" in body + assert "page_size" in body + + def test_task_shape(self, client: FlaskClient, admin_token: str) -> None: + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + resp = _history(client, admin_token, eng["id"], 1) + task = resp.get_json()["tasks"][0] + for field in ("display_id", "command", "params", "status", "completed", "timestamp"): + assert field in task, f"missing field: {field}" + + def test_default_page_is_1(self, client: FlaskClient, admin_token: str) -> None: + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + resp = _history(client, admin_token, eng["id"], 1) + assert resp.get_json()["page"] == 1 + + def test_default_page_size_is_25(self, client: FlaskClient, admin_token: str) -> None: + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + resp = _history(client, admin_token, eng["id"], 1) + assert resp.get_json()["page_size"] == 25 + + def test_callback_1_has_12_total(self, client: FlaskClient, admin_token: str) -> None: + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + resp = _history(client, admin_token, eng["id"], 1) + assert resp.get_json()["total"] == 12 + + def test_callback_2_has_0_tasks(self, client: FlaskClient, admin_token: str) -> None: + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + resp = _history(client, admin_token, eng["id"], 2) + body = resp.get_json() + assert body["total"] == 0 + assert body["tasks"] == [] + + def test_pagination_page_size_applied(self, client: FlaskClient, admin_token: str) -> None: + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + resp = _history(client, admin_token, eng["id"], 1, page=1, page_size=5) + body = resp.get_json() + assert len(body["tasks"]) == 5 + assert body["page_size"] == 5 + + def test_redteam_can_view_history( + self, client: FlaskClient, admin_token: str, redteam_token: str + ) -> None: + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + resp = _history(client, redteam_token, eng["id"], 1) + assert resp.status_code == 200 + + +# --------------------------------------------------------------------------- +# Validation errors +# --------------------------------------------------------------------------- + + +class TestHistoryValidation: + def test_400_page_size_too_large(self, client: FlaskClient, admin_token: str) -> None: + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + resp = _history(client, admin_token, eng["id"], 1, page_size=101) + assert resp.status_code == 400 + + def test_400_page_zero(self, client: FlaskClient, admin_token: str) -> None: + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + resp = _history(client, admin_token, eng["id"], 1, page=0) + assert resp.status_code == 400 + + def test_400_page_size_zero(self, client: FlaskClient, admin_token: str) -> None: + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + resp = _history(client, admin_token, eng["id"], 1, page_size=0) + assert resp.status_code == 400 + + def test_400_page_negative(self, client: FlaskClient, admin_token: str) -> None: + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + resp = _history(client, admin_token, eng["id"], 1, page=-1) + assert resp.status_code == 400 + + def test_400_page_size_100_is_ok(self, client: FlaskClient, admin_token: str) -> None: + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + resp = _history(client, admin_token, eng["id"], 1, page_size=100) + assert resp.status_code == 200 + + +# --------------------------------------------------------------------------- +# Authorization / error cases +# --------------------------------------------------------------------------- + + +class TestHistoryErrors: + def test_403_soc( + self, client: FlaskClient, admin_token: str, soc_token: str + ) -> None: + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + resp = _history(client, soc_token, eng["id"], 1) + assert resp.status_code == 403 + + def test_503_no_key( + self, monkeypatch, client: FlaskClient, admin_token: str + ) -> None: + monkeypatch.delenv("MIMIC_ENCRYPTION_KEY", raising=False) + eng = _make_engagement(client, admin_token) + resp = _history(client, admin_token, eng["id"], 1) + assert resp.status_code == 503 + + def test_404_engagement_not_found( + self, client: FlaskClient, admin_token: str + ) -> None: + resp = _history(client, admin_token, 9999, 1) + assert resp.status_code == 404 + + def test_404_no_c2_config( + self, client: FlaskClient, admin_token: str + ) -> None: + eng = _make_engagement(client, admin_token) + resp = _history(client, admin_token, eng["id"], 1) + assert resp.status_code == 404 + + def test_502_adapter_error( + self, monkeypatch, client: FlaskClient, admin_token: str + ) -> None: + from backend.app.services.c2 import fake as fake_mod + + def _boom(self, callback_display_id, page=1, page_size=25): + raise C2Error("upstream error") + + monkeypatch.setattr(fake_mod.FakeAdapter, "list_callback_tasks", _boom) + + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + resp = _history(client, admin_token, eng["id"], 1) + assert resp.status_code == 502 + assert "upstream error" in resp.get_json().get("error", "") diff --git a/backend/tests/test_c2_import.py b/backend/tests/test_c2_import.py new file mode 100644 index 0000000..9f289c8 --- /dev/null +++ b/backend/tests/test_c2_import.py @@ -0,0 +1,418 @@ +"""Tests for POST /api/simulations//c2/import.""" +from __future__ import annotations + +import pytest +from cryptography.fernet import Fernet +from flask import Flask +from flask.testing import FlaskClient + +from backend.app.extensions import db +from backend.app.models.c2_task import C2Task, C2TaskSource +from backend.app.models.simulation import Simulation, SimulationStatus +from backend.app.services.c2.adapter import C2Error, C2TaskStatus +from backend.tests.conftest import auth_headers as _h + +_FERNET_KEY = Fernet.generate_key().decode() + + +@pytest.fixture(autouse=True) +def set_encryption_key(monkeypatch): + monkeypatch.setenv("MIMIC_ENCRYPTION_KEY", _FERNET_KEY) + + +@pytest.fixture(autouse=True) +def use_fake_adapter(monkeypatch): + monkeypatch.setenv("MIMIC_C2_ADAPTER", "fake") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_engagement(client: FlaskClient, token: str) -> dict: + resp = client.post( + "/api/engagements", + headers=_h(token), + json={"name": "Op Alpha", "start_date": "2026-06-10"}, + ) + assert resp.status_code == 201 + return resp.get_json() + + +def _put_config(client: FlaskClient, token: str, eid: int) -> None: + resp = client.put( + f"/api/engagements/{eid}/c2-config", + headers=_h(token), + json={"url": "https://c2.internal:7443", "api_token": "s3cr3t", "verify_tls": True}, + ) + assert resp.status_code == 200 + + +def _make_sim(client: FlaskClient, token: str, eid: int) -> dict: + resp = client.post( + f"/api/engagements/{eid}/simulations", + headers=_h(token), + json={"name": "Sim Alpha"}, + ) + assert resp.status_code == 201 + return resp.get_json() + + +def _import(client: FlaskClient, token: str, sid: int, task_display_ids: list, callback_display_id: int = 1): + return client.post( + f"/api/simulations/{sid}/c2/import", + headers=_h(token), + json={"callback_display_id": callback_display_id, "task_display_ids": task_display_ids}, + ) + + +def _make_completed_get_task(monkeypatch, command: str = "whoami"): + """Patch FakeAdapter.get_task to return completed=True with a command.""" + from datetime import UTC, datetime + + from backend.app.services.c2 import fake as fake_mod + + def _completed(self, task_display_id: int) -> C2TaskStatus: + return C2TaskStatus( + display_id=task_display_id, + status="completed", + completed=True, + completed_at=datetime.now(UTC), + command=command, + ) + + monkeypatch.setattr(fake_mod.FakeAdapter, "get_task", _completed) + + def _output(self, task_display_id: int) -> str: + return f"output for {task_display_id}" + + monkeypatch.setattr(fake_mod.FakeAdapter, "get_task_output", _output) + + +def _advance_to_review_required(client, token, sid): + client.patch(f"/api/simulations/{sid}", headers=_h(token), json={"name": "Sim Alpha"}) + client.post(f"/api/simulations/{sid}/transition", headers=_h(token), json={"to": "review_required"}) + + +def _advance_to_done(client, admin_token, soc_token, sid): + _advance_to_review_required(client, admin_token, sid) + client.post(f"/api/simulations/{sid}/transition", headers=_h(soc_token), json={"to": "done"}) + + +# --------------------------------------------------------------------------- +# Happy path +# --------------------------------------------------------------------------- + + +class TestImportHappyPath: + def test_imports_two_completed_tasks( + self, app: Flask, monkeypatch, client: FlaskClient, admin_token: str + ) -> None: + _make_completed_get_task(monkeypatch, command="whoami") + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + sim = _make_sim(client, admin_token, eng["id"]) + + resp = _import(client, admin_token, sim["id"], [100, 101]) + assert resp.status_code == 200 + body = resp.get_json() + assert body["imported"] == 2 + assert body["skipped"] == 0 + + with app.app_context(): + rows = C2Task.query.filter_by(simulation_id=sim["id"]).all() + assert len(rows) == 2 + + def test_imported_tasks_have_source_import( + self, app: Flask, monkeypatch, client: FlaskClient, admin_token: str + ) -> None: + _make_completed_get_task(monkeypatch) + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + sim = _make_sim(client, admin_token, eng["id"]) + + _import(client, admin_token, sim["id"], [100]) + + with app.app_context(): + task = C2Task.query.filter_by(simulation_id=sim["id"]).first() + assert task is not None + assert task.source == C2TaskSource.IMPORT + + def test_completed_tasks_get_mapping_applied( + self, app: Flask, monkeypatch, client: FlaskClient, admin_token: str + ) -> None: + _make_completed_get_task(monkeypatch) + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + sim = _make_sim(client, admin_token, eng["id"]) + + _import(client, admin_token, sim["id"], [100]) + + with app.app_context(): + task = C2Task.query.filter_by(simulation_id=sim["id"]).first() + assert task is not None + assert task.mapping_applied is True + + def test_idempotent_import_counts_skipped( + self, monkeypatch, client: FlaskClient, admin_token: str + ) -> None: + _make_completed_get_task(monkeypatch) + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + sim = _make_sim(client, admin_token, eng["id"]) + + # First import. + _import(client, admin_token, sim["id"], [100, 101]) + + # Second import with one overlap. + resp = _import(client, admin_token, sim["id"], [100, 102]) + body = resp.get_json() + assert body["imported"] == 1 + assert body["skipped"] == 1 + + def test_auto_transition_pending_to_in_progress( + self, app: Flask, monkeypatch, client: FlaskClient, admin_token: str + ) -> None: + _make_completed_get_task(monkeypatch) + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + sim = _make_sim(client, admin_token, eng["id"]) + assert sim["status"] == "pending" + + _import(client, admin_token, sim["id"], [100]) + + with app.app_context(): + updated = db.session.get(Simulation, sim["id"]) + assert updated is not None + assert updated.status == SimulationStatus.IN_PROGRESS + + def test_no_transition_when_already_in_progress( + self, app: Flask, monkeypatch, client: FlaskClient, admin_token: str + ) -> None: + _make_completed_get_task(monkeypatch) + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + sim = _make_sim(client, admin_token, eng["id"]) + # Advance to in_progress manually. + client.patch( + f"/api/simulations/{sim['id']}", + headers=_h(admin_token), + json={"name": "Sim Alpha"}, + ) + + _import(client, admin_token, sim["id"], [100]) + + with app.app_context(): + updated = db.session.get(Simulation, sim["id"]) + assert updated is not None + assert updated.status == SimulationStatus.IN_PROGRESS + + def test_no_transition_when_review_required( + self, app: Flask, monkeypatch, client: FlaskClient, admin_token: str + ) -> None: + _make_completed_get_task(monkeypatch) + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + sim = _make_sim(client, admin_token, eng["id"]) + _advance_to_review_required(client, admin_token, sim["id"]) + + _import(client, admin_token, sim["id"], [100]) + + with app.app_context(): + updated = db.session.get(Simulation, sim["id"]) + assert updated is not None + assert updated.status == SimulationStatus.REVIEW_REQUIRED + + def test_incomplete_task_stored_without_mapping( + self, app: Flask, monkeypatch, client: FlaskClient, admin_token: str + ) -> None: + """An incomplete task is stored as-is; mapping_applied stays False.""" + from backend.app.services.c2 import fake as fake_mod + + def _submitted(self, task_display_id: int) -> C2TaskStatus: + return C2TaskStatus( + display_id=task_display_id, + status="submitted", + completed=False, + command="shell", + ) + + monkeypatch.setattr(fake_mod.FakeAdapter, "get_task", _submitted) + + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + sim = _make_sim(client, admin_token, eng["id"]) + + resp = _import(client, admin_token, sim["id"], [200]) + assert resp.status_code == 200 + assert resp.get_json()["imported"] == 1 + + with app.app_context(): + task = C2Task.query.filter_by(simulation_id=sim["id"]).first() + assert task is not None + assert task.completed is False + assert task.mapping_applied is False + assert task.output is None + + def test_command_stored_from_get_task( + self, app: Flask, monkeypatch, client: FlaskClient, admin_token: str + ) -> None: + """Command field on the stored row comes from adapter.get_task().command.""" + _make_completed_get_task(monkeypatch, command="net user /domain") + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + sim = _make_sim(client, admin_token, eng["id"]) + + _import(client, admin_token, sim["id"], [100]) + + with app.app_context(): + task = C2Task.query.filter_by(simulation_id=sim["id"]).first() + assert task is not None + assert task.command == "net user /domain" + + def test_redteam_can_import( + self, monkeypatch, client: FlaskClient, admin_token: str, redteam_token: str + ) -> None: + _make_completed_get_task(monkeypatch) + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + sim = _make_sim(client, admin_token, eng["id"]) + + resp = _import(client, redteam_token, sim["id"], [100]) + assert resp.status_code == 200 + + def test_no_transition_when_all_skipped( + self, app: Flask, monkeypatch, client: FlaskClient, admin_token: str + ) -> None: + """If imported=0 (all skipped), do not transition pending→in_progress.""" + _make_completed_get_task(monkeypatch) + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + sim = _make_sim(client, admin_token, eng["id"]) + + _import(client, admin_token, sim["id"], [100]) # first import + _import(client, admin_token, sim["id"], []) # empty — should 400 before this matters + + # Reset to pending state via a fresh sim (can't undo, just verify the 0-skipped case). + # We test: importing same task again = skipped=1, imported=0 → no double-transition. + resp = _import(client, admin_token, sim["id"], [100]) + body = resp.get_json() + assert body["imported"] == 0 + assert body["skipped"] == 1 + + +# --------------------------------------------------------------------------- +# Validation errors +# --------------------------------------------------------------------------- + + +class TestImportValidation: + def test_400_empty_task_display_ids( + self, client: FlaskClient, admin_token: str + ) -> None: + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + sim = _make_sim(client, admin_token, eng["id"]) + + resp = _import(client, admin_token, sim["id"], []) + assert resp.status_code == 400 + + def test_400_non_int_task_display_id( + self, client: FlaskClient, admin_token: str + ) -> None: + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + sim = _make_sim(client, admin_token, eng["id"]) + + resp = client.post( + f"/api/simulations/{sim['id']}/c2/import", + headers=_h(admin_token), + json={"callback_display_id": 1, "task_display_ids": ["not-an-int"]}, + ) + assert resp.status_code == 400 + + def test_400_missing_callback_display_id( + self, client: FlaskClient, admin_token: str + ) -> None: + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + sim = _make_sim(client, admin_token, eng["id"]) + + resp = client.post( + f"/api/simulations/{sim['id']}/c2/import", + headers=_h(admin_token), + json={"task_display_ids": [100]}, + ) + assert resp.status_code == 400 + + def test_409_done_simulation( + self, client: FlaskClient, admin_token: str, soc_token: str + ) -> None: + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + sim = _make_sim(client, admin_token, eng["id"]) + _advance_to_done(client, admin_token, soc_token, sim["id"]) + + resp = _import(client, admin_token, sim["id"], [100]) + assert resp.status_code == 409 + + def test_404_simulation_not_found( + self, client: FlaskClient, admin_token: str + ) -> None: + resp = _import(client, admin_token, 9999, [100]) + assert resp.status_code == 404 + + +# --------------------------------------------------------------------------- +# Authorization / error cases +# --------------------------------------------------------------------------- + + +class TestImportErrors: + def test_403_soc( + self, client: FlaskClient, admin_token: str, soc_token: str + ) -> None: + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + sim = _make_sim(client, admin_token, eng["id"]) + + resp = _import(client, soc_token, sim["id"], [100]) + assert resp.status_code == 403 + + def test_503_no_key( + self, monkeypatch, client: FlaskClient, admin_token: str + ) -> None: + monkeypatch.delenv("MIMIC_ENCRYPTION_KEY", raising=False) + eng = _make_engagement(client, admin_token) + sim = _make_sim(client, admin_token, eng["id"]) + + resp = _import(client, admin_token, sim["id"], [100]) + assert resp.status_code == 503 + + def test_404_no_c2_config( + self, client: FlaskClient, admin_token: str + ) -> None: + eng = _make_engagement(client, admin_token) + sim = _make_sim(client, admin_token, eng["id"]) + + resp = _import(client, admin_token, sim["id"], [100]) + assert resp.status_code == 404 + + def test_502_adapter_error_on_get_task( + self, monkeypatch, client: FlaskClient, admin_token: str + ) -> None: + from backend.app.services.c2 import fake as fake_mod + + def _boom(self, task_display_id: int) -> C2TaskStatus: + raise C2Error("Mythic unreachable") + + monkeypatch.setattr(fake_mod.FakeAdapter, "get_task", _boom) + + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + sim = _make_sim(client, admin_token, eng["id"]) + + resp = _import(client, admin_token, sim["id"], [100]) + assert resp.status_code == 502 + assert "Mythic unreachable" in resp.get_json().get("error", "")