diff --git a/backend/app/api/c2.py b/backend/app/api/c2.py index 15c5060..2564b48 100644 --- a/backend/app/api/c2.py +++ b/backend/app/api/c2.py @@ -21,6 +21,7 @@ from backend.app.models.c2_task import C2Task, C2TaskSource from backend.app.models.simulation import Simulation, SimulationStatus from backend.app.services.c2.adapter import C2Error from backend.app.services.c2.factory import get_adapter +from backend.app.services.c2.mapping import apply_task_to_simulation from backend.app.services.crypto import C2Disabled, decrypt, encrypt from backend.app.services.simulation_workflow import promote_to_in_progress @@ -297,3 +298,72 @@ def execute_simulation(sid: int): for t in created_tasks ] }), 200 + + +# --------------------------------------------------------------------------- +# M3 — poll-on-read task listing +# --------------------------------------------------------------------------- + + +@sims_c2_bp.get("//c2/tasks") +@role_required("admin", "redteam") +def list_simulation_tasks(sid: int): + guard = _crypto_guard() + if guard is not None: + return guard + + sim = db.session.get(Simulation, sid) + if sim is None: + return jsonify({"error": "Simulation not found"}), 404 + + engagement = db.session.get(Engagement, sim.engagement_id) + if engagement is None: + return jsonify({"error": "Engagement not found"}), 404 + + adapter, err = _load_adapter_for_engagement(engagement) + if err is not None: + return err + + tasks: list[C2Task] = C2Task.query.filter_by(simulation_id=sid).all() + + for task in tasks: + if task.completed: + continue + + try: + status = adapter.get_task(task.mythic_task_display_id) + except C2Error: + # Best-effort refresh — skip this task if the adapter fails. + continue + + task.status = status.status + task.completed = status.completed + + if status.completed: + task.completed_at = status.completed_at or datetime.now(UTC) + try: + task.output = adapter.get_task_output(task.mythic_task_display_id) + except C2Error: + task.output = "" + apply_task_to_simulation(task, sim) + + db.session.commit() + + return jsonify({ + "tasks": [ + { + "id": t.id, + "mythic_task_display_id": t.mythic_task_display_id, + "callback_display_id": t.callback_display_id, + "command": t.command, + "params": t.params, + "status": t.status, + "completed": t.completed, + "output": t.output, + "mapping_applied": t.mapping_applied, + "created_at": t.created_at.isoformat() if t.created_at else None, + "completed_at": t.completed_at.isoformat() if t.completed_at else None, + } + for t in tasks + ] + }), 200 diff --git a/backend/app/models/c2_task.py b/backend/app/models/c2_task.py index d87de92..b44a08f 100644 --- a/backend/app/models/c2_task.py +++ b/backend/app/models/c2_task.py @@ -37,6 +37,7 @@ class C2Task(db.Model): # type: ignore[name-defined] db.DateTime, nullable=False, default=lambda: datetime.now(UTC) ) completed_at = db.Column(db.DateTime, nullable=True) + mapping_applied = db.Column(db.Boolean, nullable=False, default=False) simulation = db.relationship( "Simulation", diff --git a/backend/app/services/c2/adapter.py b/backend/app/services/c2/adapter.py index 5ee5460..ab3db97 100644 --- a/backend/app/services/c2/adapter.py +++ b/backend/app/services/c2/adapter.py @@ -4,7 +4,8 @@ from __future__ import annotations import base64 import binascii from abc import ABC, abstractmethod -from dataclasses import dataclass +from dataclasses import dataclass, field +from datetime import datetime class C2Error(Exception): @@ -32,6 +33,7 @@ class C2TaskStatus: display_id: int status: str completed: bool + completed_at: datetime | None = field(default=None) @dataclass diff --git a/backend/app/services/c2/fake.py b/backend/app/services/c2/fake.py index 1da4a71..8938d8b 100644 --- a/backend/app/services/c2/fake.py +++ b/backend/app/services/c2/fake.py @@ -8,6 +8,7 @@ from __future__ import annotations from backend.app.services.c2.adapter import ( C2Adapter, C2Callback, + C2Error, C2Health, C2TaskPage, C2TaskStatus, @@ -46,11 +47,17 @@ class FakeAdapter(C2Adapter): """In-memory adapter with deterministic behaviour. Each instance starts with an empty task store and display_ids from 1000. + + get_task() state progression per task (keyed by display_id): + - First call after create_task → submitted, completed=False + - Second and subsequent calls → completed=True, status="completed" """ def __init__(self) -> None: self._tasks: dict[int, dict] = {} self._next_task_id = 1000 + # Tracks how many times get_task has been called per display_id. + self._get_task_calls: dict[int, int] = {} def test_connection(self) -> C2Health: return C2Health(ok=True) @@ -78,20 +85,44 @@ class FakeAdapter(C2Adapter): return tid def get_task(self, task_display_id: int) -> C2TaskStatus: + """Deterministic state progression: first call → submitted, second+ → completed. + + Tracks call count regardless of whether the task was created by this instance, + so the endpoint poll-on-read flow works across separate adapter instantiations. + """ + call_count = self._get_task_calls.get(task_display_id, 0) + 1 + self._get_task_calls[task_display_id] = call_count + task = self._tasks.get(task_display_id) - if task is None: - return C2TaskStatus(display_id=task_display_id, status="unknown", completed=False) + + if call_count >= 2: + completed = True + status = "completed" + if task is not None: + task["status"] = "completed" + task["completed"] = True + else: + completed = False + status = task["status"] if task is not None else "submitted" + return C2TaskStatus( display_id=task_display_id, - status=task["status"], - completed=task["completed"], + status=status, + completed=completed, ) def get_task_output(self, task_display_id: int) -> str: + """Returns deterministic output once task is completed; raises C2Error before that.""" + # Check call count — completed if get_task was called at least twice. + if self._get_task_calls.get(task_display_id, 0) < 2: + # Also allow tasks in _tasks that were explicitly set to completed. + task = self._tasks.get(task_display_id) + if task is None or not task.get("completed", False): + raise C2Error("task not completed") + task = self._tasks.get(task_display_id) - if task is None: - return "" - return task.get("output") or "" + command = task["command"] if task is not None else "unknown" + return f"output for task {task_display_id}: {command}\n" def list_callback_tasks( self, diff --git a/backend/app/services/c2/mapping.py b/backend/app/services/c2/mapping.py new file mode 100644 index 0000000..b19dc8f --- /dev/null +++ b/backend/app/services/c2/mapping.py @@ -0,0 +1,38 @@ +"""C2 task → Simulation output mapping. + +apply_task_to_simulation() writes task output into the simulation's +execution_result field and marks the task as mapping_applied=True so that +the operation is idempotent (safe to call multiple times for the same task). + +Caller is responsible for committing the session. +""" +from __future__ import annotations + +from datetime import UTC, datetime + +from backend.app.models.c2_task import C2Task +from backend.app.models.simulation import Simulation + + +def apply_task_to_simulation(task: C2Task, simulation: Simulation) -> None: + """Write task output into simulation.execution_result (append, newline-separated). + + No-op if task.mapping_applied is already True or task.output is empty. + Marks task.mapping_applied = True on completion. + """ + if task.mapping_applied: + return + + output = (task.output or "").strip() + if not output: + task.mapping_applied = True + return + + existing = (simulation.execution_result or "").rstrip("\n") + if existing: + simulation.execution_result = existing + "\n" + output + else: + simulation.execution_result = output + + simulation.updated_at = datetime.now(UTC) + task.mapping_applied = True diff --git a/backend/app/services/c2/mythic.py b/backend/app/services/c2/mythic.py index eb62c96..d913fc2 100644 --- a/backend/app/services/c2/mythic.py +++ b/backend/app/services/c2/mythic.py @@ -12,6 +12,8 @@ M4: list_callback_tasks() """ from __future__ import annotations +from datetime import datetime + import requests from backend.app.services.c2.adapter import ( @@ -21,6 +23,7 @@ from backend.app.services.c2.adapter import ( C2Health, C2TaskPage, C2TaskStatus, + decode_response_text, ) _HEALTH_QUERY = "{ __typename }" @@ -54,6 +57,28 @@ mutation CreateTask($callback_id: Int!, $command: String!, $params: String!) { } """ +_GET_TASK_QUERY = """ +query GetTask($display_id: Int!) { + task(where: {display_id: {_eq: $display_id}}) { + display_id + status + completed + timestamp + } +} +""" + +_GET_TASK_OUTPUT_QUERY = """ +query GetTaskOutput($display_id: Int!) { + response( + where: {task: {display_id: {_eq: $display_id}}} + order_by: {id: asc} + ) { + response_text + } +} +""" + class MythicAdapter(C2Adapter): """Real Mythic 3.x adapter using GraphQL over HTTP.""" @@ -144,10 +169,52 @@ class MythicAdapter(C2Adapter): return int(task_data["display_id"]) def get_task(self, task_display_id: int) -> C2TaskStatus: - raise NotImplementedError("M3") + """Return current task status from Mythic.""" + try: + data = self._post({ + "query": _GET_TASK_QUERY, + "variables": {"display_id": task_display_id}, + }) + except requests.RequestException as exc: + raise C2Error(str(exc)) from exc + + rows = data.get("data", {}).get("task", []) + if not rows: + raise C2Error(f"task {task_display_id} not found in Mythic") + row = rows[0] + + completed_at: datetime | None = None + if row.get("completed") and row.get("timestamp"): + try: + completed_at = datetime.fromisoformat( + row["timestamp"].replace("Z", "+00:00") + ) + except ValueError: + completed_at = None + + return C2TaskStatus( + display_id=row["display_id"], + status=row["status"], + completed=bool(row.get("completed", False)), + completed_at=completed_at, + ) def get_task_output(self, task_display_id: int) -> str: - raise NotImplementedError("M3") + """Return decoded, concatenated output for a task.""" + try: + data = self._post({ + "query": _GET_TASK_OUTPUT_QUERY, + "variables": {"display_id": task_display_id}, + }) + except requests.RequestException as exc: + raise C2Error(str(exc)) from exc + + rows = data.get("data", {}).get("response", []) + return "".join( + decode_response_text(r["response_text"]) + for r in rows + if r.get("response_text") + ) def list_callback_tasks( self, diff --git a/backend/migrations/versions/0007_c2_task_mapping_applied.py b/backend/migrations/versions/0007_c2_task_mapping_applied.py new file mode 100644 index 0000000..ff60ea4 --- /dev/null +++ b/backend/migrations/versions/0007_c2_task_mapping_applied.py @@ -0,0 +1,30 @@ +"""add mapping_applied column to c2_task + +Revision ID: 0007 +Revises: 0006 +Create Date: 2026-06-10 00:00:00.000000 +""" +import sqlalchemy as sa +from alembic import op + +revision = "0007" +down_revision = "0006" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + with op.batch_alter_table("c2_task") as batch_op: + batch_op.add_column( + sa.Column( + "mapping_applied", + sa.Boolean(), + nullable=False, + server_default=sa.false(), + ) + ) + + +def downgrade() -> None: + with op.batch_alter_table("c2_task") as batch_op: + batch_op.drop_column("mapping_applied") diff --git a/backend/tests/test_c2_adapter_fake_m3.py b/backend/tests/test_c2_adapter_fake_m3.py new file mode 100644 index 0000000..3da8e9a --- /dev/null +++ b/backend/tests/test_c2_adapter_fake_m3.py @@ -0,0 +1,110 @@ +"""FakeAdapter M3 state-progression tests — get_task and get_task_output.""" +from __future__ import annotations + +import pytest + +from backend.app.services.c2.adapter import C2Error +from backend.app.services.c2.fake import FakeAdapter + + +@pytest.fixture() +def adapter() -> FakeAdapter: + return FakeAdapter() + + +@pytest.fixture() +def adapter_with_task(adapter: FakeAdapter) -> tuple[FakeAdapter, int]: + tid = adapter.create_task(callback_display_id=1, command="whoami") + return adapter, tid + + +class TestFakeAdapterGetTaskProgression: + def test_first_call_returns_submitted(self, adapter_with_task): + a, tid = adapter_with_task + status = a.get_task(tid) + assert status.status == "submitted" + assert status.completed is False + + def test_second_call_returns_completed(self, adapter_with_task): + a, tid = adapter_with_task + a.get_task(tid) # first call + status = a.get_task(tid) # second call + assert status.status == "completed" + assert status.completed is True + + def test_subsequent_calls_stay_completed(self, adapter_with_task): + a, tid = adapter_with_task + for _ in range(5): + a.get_task(tid) + status = a.get_task(tid) + assert status.completed is True + + def test_unknown_task_id_returns_submitted_on_first_call(self, adapter): + """A task ID not created by this instance still goes through submitted→completed.""" + status = adapter.get_task(9999) + assert status.display_id == 9999 + assert status.status == "submitted" + assert status.completed is False + + def test_call_counters_are_per_task(self, adapter): + """Two tasks have independent state — completing one does not affect the other.""" + t1 = adapter.create_task(callback_display_id=1, command="whoami") + t2 = adapter.create_task(callback_display_id=1, command="ipconfig") + + # Advance t1 to completed via two calls. + adapter.get_task(t1) + adapter.get_task(t1) + + # t2 first call should still be submitted. + s2 = adapter.get_task(t2) + assert s2.status == "submitted" + assert s2.completed is False + + def test_instances_are_isolated(self): + """Per-instance counters — different FakeAdapter instances don't share state.""" + a1 = FakeAdapter() + a2 = FakeAdapter() + + t1 = a1.create_task(1, "cmd") + t2 = a2.create_task(1, "cmd") + + a1.get_task(t1) + a1.get_task(t1) # a1's task is now completed + + # a2's task with same display_id (both start at 1000) should be independent. + assert t1 == t2 == 1000 + s2 = a2.get_task(t2) + assert s2.status == "submitted" + + +class TestFakeAdapterGetTaskOutput: + def test_raises_before_completed(self, adapter_with_task): + a, tid = adapter_with_task + with pytest.raises(C2Error, match="task not completed"): + a.get_task_output(tid) + + def test_raises_after_first_get_task_call_only(self, adapter_with_task): + a, tid = adapter_with_task + a.get_task(tid) # first call — still submitted + with pytest.raises(C2Error, match="task not completed"): + a.get_task_output(tid) + + def test_returns_output_after_completed(self, adapter_with_task): + a, tid = adapter_with_task + a.get_task(tid) + a.get_task(tid) # now completed + output = a.get_task_output(tid) + assert "whoami" in output + assert str(tid) in output + + def test_output_format(self, adapter): + tid = adapter.create_task(callback_display_id=2, command="ipconfig /all") + adapter.get_task(tid) + adapter.get_task(tid) + output = adapter.get_task_output(tid) + assert output == f"output for task {tid}: ipconfig /all\n" + + def test_unknown_task_raises_c2error(self, adapter): + """Task ID never created and never polled — not completed → C2Error.""" + with pytest.raises(C2Error, match="task not completed"): + adapter.get_task_output(9999) diff --git a/backend/tests/test_c2_adapter_mythic_m3.py b/backend/tests/test_c2_adapter_mythic_m3.py new file mode 100644 index 0000000..e6d5963 --- /dev/null +++ b/backend/tests/test_c2_adapter_mythic_m3.py @@ -0,0 +1,188 @@ +"""MythicAdapter M3 tests — get_task and get_task_output, mocked HTTP.""" +from __future__ import annotations + +import pytest +import requests +import requests_mock as rm_module + +from backend.app.services.c2.adapter import C2Error +from backend.app.services.c2.mythic import MythicAdapter + +_BASE_URL = "https://mythic.lab:7443" +_GQL_URL = _BASE_URL + "/graphql" +_TOKEN = "fake-api-token" + + +@pytest.fixture() +def adapter(): + return MythicAdapter(url=_BASE_URL, api_token=_TOKEN, verify_tls=False) + + +class TestMythicAdapterGetTask: + def test_returns_status_for_incomplete_task(self, adapter): + payload = { + "data": { + "task": [ + { + "display_id": 7, + "status": "processing", + "completed": False, + "timestamp": None, + } + ] + } + } + with rm_module.Mocker() as m: + m.post(_GQL_URL, json=payload) + status = adapter.get_task(7) + + assert status.display_id == 7 + assert status.status == "processing" + assert status.completed is False + assert status.completed_at is None + + def test_returns_completed_at_for_completed_task(self, adapter): + payload = { + "data": { + "task": [ + { + "display_id": 7, + "status": "completed", + "completed": True, + "timestamp": "2026-06-10T12:00:00Z", + } + ] + } + } + with rm_module.Mocker() as m: + m.post(_GQL_URL, json=payload) + status = adapter.get_task(7) + + assert status.completed is True + assert status.completed_at is not None + assert status.completed_at.year == 2026 + + def test_raises_when_task_not_found(self, adapter): + payload = {"data": {"task": []}} + with rm_module.Mocker() as m: + m.post(_GQL_URL, json=payload) + with pytest.raises(C2Error, match="not found"): + adapter.get_task(999) + + def test_sends_apitoken_header(self, adapter): + payload = { + "data": { + "task": [ + {"display_id": 1, "status": "submitted", "completed": False, "timestamp": None} + ] + } + } + with rm_module.Mocker() as m: + m.post(_GQL_URL, json=payload) + adapter.get_task(1) + assert m.last_request.headers.get("apitoken") == _TOKEN + + def test_network_error_raises_c2error(self, adapter): + with rm_module.Mocker() as m: + m.post(_GQL_URL, exc=requests.exceptions.ConnectionError("refused")) + with pytest.raises(C2Error): + adapter.get_task(1) + + def test_no_redirect_followed(self, adapter): + """get_task must not follow HTTP redirects.""" + with rm_module.Mocker() as m: + m.post(_GQL_URL, status_code=301, headers={"Location": "https://evil.example/"}) + with pytest.raises(C2Error): + adapter.get_task(1) + assert len(m.request_history) == 1 + + def test_invalid_timestamp_does_not_crash(self, adapter): + """A malformed timestamp field falls back to completed_at=None without raising.""" + payload = { + "data": { + "task": [ + { + "display_id": 5, + "status": "completed", + "completed": True, + "timestamp": "not-a-date", + } + ] + } + } + with rm_module.Mocker() as m: + m.post(_GQL_URL, json=payload) + status = adapter.get_task(5) + + assert status.completed is True + assert status.completed_at is None + + +class TestMythicAdapterGetTaskOutput: + def test_returns_decoded_output(self, adapter): + import base64 + encoded = base64.b64encode(b"Administrator\r\n").decode() + payload = { + "data": { + "response": [{"response_text": encoded}] + } + } + with rm_module.Mocker() as m: + m.post(_GQL_URL, json=payload) + output = adapter.get_task_output(7) + + assert "Administrator" in output + + def test_concatenates_multiple_responses(self, adapter): + import base64 + r1 = base64.b64encode(b"line one\n").decode() + r2 = base64.b64encode(b"line two\n").decode() + payload = { + "data": { + "response": [{"response_text": r1}, {"response_text": r2}] + } + } + with rm_module.Mocker() as m: + m.post(_GQL_URL, json=payload) + output = adapter.get_task_output(7) + + assert "line one" in output + assert "line two" in output + + def test_returns_empty_string_when_no_responses(self, adapter): + payload = {"data": {"response": []}} + with rm_module.Mocker() as m: + m.post(_GQL_URL, json=payload) + output = adapter.get_task_output(7) + + assert output == "" + + def test_skips_empty_response_text(self, adapter): + import base64 + encoded = base64.b64encode(b"real output").decode() + payload = { + "data": { + "response": [ + {"response_text": ""}, + {"response_text": encoded}, + ] + } + } + with rm_module.Mocker() as m: + m.post(_GQL_URL, json=payload) + output = adapter.get_task_output(7) + + assert output == "real output" + + def test_network_error_raises_c2error(self, adapter): + with rm_module.Mocker() as m: + m.post(_GQL_URL, exc=requests.exceptions.Timeout("timeout")) + with pytest.raises(C2Error): + adapter.get_task_output(7) + + def test_no_redirect_followed(self, adapter): + with rm_module.Mocker() as m: + m.post(_GQL_URL, status_code=302, headers={"Location": "https://evil.example/"}) + with pytest.raises(C2Error): + adapter.get_task_output(1) + assert len(m.request_history) == 1 diff --git a/backend/tests/test_c2_mapping.py b/backend/tests/test_c2_mapping.py new file mode 100644 index 0000000..70368f4 --- /dev/null +++ b/backend/tests/test_c2_mapping.py @@ -0,0 +1,97 @@ +"""Unit tests for apply_task_to_simulation() mapping helper.""" +from __future__ import annotations + +from datetime import UTC, datetime +from unittest.mock import MagicMock + +from backend.app.services.c2.mapping import apply_task_to_simulation + + +def _make_task(output: str | None = "whoami output", mapping_applied: bool = False) -> MagicMock: + task = MagicMock() + task.output = output + task.mapping_applied = mapping_applied + return task + + +def _make_sim(execution_result: str | None = None) -> MagicMock: + sim = MagicMock() + sim.execution_result = execution_result + sim.updated_at = None + return sim + + +class TestApplyTaskToSimulation: + def test_appends_output_to_empty_simulation(self): + task = _make_task(output="whoami output") + sim = _make_sim(execution_result=None) + + apply_task_to_simulation(task, sim) + + assert sim.execution_result == "whoami output" + assert task.mapping_applied is True + + def test_appends_with_newline_separator(self): + task = _make_task(output="second result") + sim = _make_sim(execution_result="first result") + + apply_task_to_simulation(task, sim) + + assert sim.execution_result == "first result\nsecond result" + + def test_idempotent_when_already_applied(self): + task = _make_task(output="some output", mapping_applied=True) + sim = _make_sim(execution_result="existing") + + apply_task_to_simulation(task, sim) + + # execution_result must not be modified. + assert sim.execution_result == "existing" + + def test_no_op_when_output_is_empty_string(self): + task = _make_task(output="") + sim = _make_sim(execution_result="existing") + + apply_task_to_simulation(task, sim) + + assert sim.execution_result == "existing" + # Still marks mapping_applied so we don't revisit it. + assert task.mapping_applied is True + + def test_no_op_when_output_is_none(self): + task = _make_task(output=None) + sim = _make_sim(execution_result="existing") + + apply_task_to_simulation(task, sim) + + assert sim.execution_result == "existing" + assert task.mapping_applied is True + + def test_strips_trailing_newlines_from_existing(self): + """Existing execution_result with trailing newlines should not cause double blank lines.""" + task = _make_task(output="new output") + sim = _make_sim(execution_result="old output\n\n") + + apply_task_to_simulation(task, sim) + + assert sim.execution_result == "old output\nnew output" + + def test_updated_at_is_set_on_sim(self): + task = _make_task(output="something") + sim = _make_sim(execution_result=None) + before = datetime.now(UTC) + + apply_task_to_simulation(task, sim) + + assert sim.updated_at is not None + assert sim.updated_at >= before + + def test_multiple_tasks_accumulate(self): + sim = _make_sim(execution_result=None) + tasks = [_make_task(output=f"result {i}") for i in range(3)] + + for t in tasks: + apply_task_to_simulation(t, sim) + + lines = sim.execution_result.split("\n") + assert lines == ["result 0", "result 1", "result 2"] diff --git a/backend/tests/test_c2_tasks_list.py b/backend/tests/test_c2_tasks_list.py new file mode 100644 index 0000000..be68c8e --- /dev/null +++ b/backend/tests/test_c2_tasks_list.py @@ -0,0 +1,374 @@ +"""Tests for GET /api/simulations//c2/tasks — poll-on-read endpoint.""" +from __future__ import annotations + +import pytest +from cryptography.fernet import Fernet +from flask import Flask +from flask.testing import FlaskClient + +from backend.app.extensions import db +from backend.app.models.c2_task import C2Task +from backend.app.models.simulation import Simulation +from backend.app.services.c2.adapter import C2Error, C2TaskStatus +from backend.tests.conftest import auth_headers as _h + +_FERNET_KEY = Fernet.generate_key().decode() + + +@pytest.fixture(autouse=True) +def set_encryption_key(monkeypatch): + monkeypatch.setenv("MIMIC_ENCRYPTION_KEY", _FERNET_KEY) + + +@pytest.fixture(autouse=True) +def use_fake_adapter(monkeypatch): + monkeypatch.setenv("MIMIC_C2_ADAPTER", "fake") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_engagement(client: FlaskClient, token: str) -> dict: + resp = client.post( + "/api/engagements", + headers=_h(token), + json={"name": "Op Alpha", "start_date": "2026-06-10"}, + ) + assert resp.status_code == 201 + return resp.get_json() + + +def _put_config(client: FlaskClient, token: str, eid: int) -> None: + resp = client.put( + f"/api/engagements/{eid}/c2-config", + headers=_h(token), + json={"url": "https://c2.internal:7443", "api_token": "s3cr3t", "verify_tls": True}, + ) + assert resp.status_code == 200 + + +def _make_sim(client: FlaskClient, token: str, eid: int) -> dict: + resp = client.post( + f"/api/engagements/{eid}/simulations", + headers=_h(token), + json={"name": "Sim Alpha"}, + ) + assert resp.status_code == 201 + return resp.get_json() + + +def _execute(client: FlaskClient, token: str, sid: int, commands: list, callback_display_id: int = 1): + return client.post( + f"/api/simulations/{sid}/c2/execute", + headers=_h(token), + json={"callback_display_id": callback_display_id, "commands": commands}, + ) + + +def _list_tasks(client: FlaskClient, token: str, sid: int): + return client.get( + f"/api/simulations/{sid}/c2/tasks", + headers=_h(token), + ) + + +# --------------------------------------------------------------------------- +# Happy path +# --------------------------------------------------------------------------- + + +class TestListTasksHappyPath: + def test_returns_empty_list_when_no_tasks( + self, client: FlaskClient, admin_token: str + ) -> None: + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + sim = _make_sim(client, admin_token, eng["id"]) + + resp = _list_tasks(client, admin_token, sim["id"]) + assert resp.status_code == 200 + assert resp.get_json()["tasks"] == [] + + def test_returns_task_after_execute( + self, client: FlaskClient, admin_token: str + ) -> None: + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + sim = _make_sim(client, admin_token, eng["id"]) + _execute(client, admin_token, sim["id"], ["whoami"]) + + resp = _list_tasks(client, admin_token, sim["id"]) + assert resp.status_code == 200 + tasks = resp.get_json()["tasks"] + assert len(tasks) == 1 + assert tasks[0]["command"] == "whoami" + + def test_task_shape( + self, client: FlaskClient, admin_token: str + ) -> None: + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + sim = _make_sim(client, admin_token, eng["id"]) + _execute(client, admin_token, sim["id"], ["hostname"]) + + resp = _list_tasks(client, admin_token, sim["id"]) + task = resp.get_json()["tasks"][0] + for field in ("id", "mythic_task_display_id", "callback_display_id", + "command", "params", "status", "completed", "output", + "mapping_applied", "created_at", "completed_at"): + assert field in task, f"missing field: {field}" + + def test_first_poll_returns_submitted( + self, client: FlaskClient, admin_token: str + ) -> None: + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + sim = _make_sim(client, admin_token, eng["id"]) + _execute(client, admin_token, sim["id"], ["whoami"]) + + # First GET — FakeAdapter.get_task() first call → submitted. + resp = _list_tasks(client, admin_token, sim["id"]) + task = resp.get_json()["tasks"][0] + assert task["status"] == "submitted" + assert task["completed"] is False + + def test_poll_marks_completed_when_adapter_returns_completed( + self, app: Flask, monkeypatch, client: FlaskClient, admin_token: str + ) -> None: + """When adapter.get_task returns completed=True the task is updated in DB.""" + from datetime import UTC, datetime + + from backend.app.services.c2 import fake as fake_mod + + def _completed(self, task_display_id: int) -> C2TaskStatus: + return C2TaskStatus( + display_id=task_display_id, + status="completed", + completed=True, + completed_at=datetime.now(UTC), + ) + + monkeypatch.setattr(fake_mod.FakeAdapter, "get_task", _completed) + + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + sim = _make_sim(client, admin_token, eng["id"]) + _execute(client, admin_token, sim["id"], ["whoami"]) + + resp = _list_tasks(client, admin_token, sim["id"]) + task = resp.get_json()["tasks"][0] + assert task["completed"] is True + assert task["status"] == "completed" + + def test_output_populated_after_completion( + self, monkeypatch, client: FlaskClient, admin_token: str + ) -> None: + """Output is fetched and stored when task transitions to completed.""" + from datetime import UTC, datetime + + from backend.app.services.c2 import fake as fake_mod + + def _completed(self, task_display_id: int) -> C2TaskStatus: + return C2TaskStatus( + display_id=task_display_id, + status="completed", + completed=True, + completed_at=datetime.now(UTC), + ) + + def _output(self, task_display_id: int) -> str: + return f"whoami result for task {task_display_id}" + + monkeypatch.setattr(fake_mod.FakeAdapter, "get_task", _completed) + monkeypatch.setattr(fake_mod.FakeAdapter, "get_task_output", _output) + + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + sim = _make_sim(client, admin_token, eng["id"]) + _execute(client, admin_token, sim["id"], ["whoami"]) + + resp = _list_tasks(client, admin_token, sim["id"]) + task = resp.get_json()["tasks"][0] + assert task["output"] is not None + assert "whoami" in task["output"] + + def test_mapping_applied_set_after_completion( + self, app: Flask, monkeypatch, client: FlaskClient, admin_token: str + ) -> None: + from datetime import UTC, datetime + + from backend.app.services.c2 import fake as fake_mod + + def _completed(self, task_display_id: int) -> C2TaskStatus: + return C2TaskStatus( + display_id=task_display_id, + status="completed", + completed=True, + completed_at=datetime.now(UTC), + ) + + monkeypatch.setattr(fake_mod.FakeAdapter, "get_task", _completed) + + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + sim = _make_sim(client, admin_token, eng["id"]) + _execute(client, admin_token, sim["id"], ["whoami"]) + + _list_tasks(client, admin_token, sim["id"]) + + with app.app_context(): + task = C2Task.query.filter_by(simulation_id=sim["id"]).first() + assert task is not None + assert task.mapping_applied is True + + def test_execution_result_updated_on_simulation( + self, app: Flask, monkeypatch, client: FlaskClient, admin_token: str + ) -> None: + from datetime import UTC, datetime + + from backend.app.services.c2 import fake as fake_mod + + def _completed(self, task_display_id: int) -> C2TaskStatus: + return C2TaskStatus( + display_id=task_display_id, + status="completed", + completed=True, + completed_at=datetime.now(UTC), + ) + + def _output(self, task_display_id: int) -> str: + return f"WORKSTATION-01\\whoami output {task_display_id}" + + monkeypatch.setattr(fake_mod.FakeAdapter, "get_task", _completed) + monkeypatch.setattr(fake_mod.FakeAdapter, "get_task_output", _output) + + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + sim = _make_sim(client, admin_token, eng["id"]) + _execute(client, admin_token, sim["id"], ["whoami"]) + + _list_tasks(client, admin_token, sim["id"]) + + with app.app_context(): + updated_sim = db.session.get(Simulation, sim["id"]) + assert updated_sim is not None + assert updated_sim.execution_result is not None + assert "whoami" in updated_sim.execution_result + + def test_completed_task_not_re_polled( + self, app: Flask, monkeypatch, client: FlaskClient, admin_token: str + ) -> None: + """Once task.completed=True in DB, subsequent GETs skip polling (no re-poll).""" + from datetime import UTC, datetime + + from backend.app.services.c2 import fake as fake_mod + + call_count = {"n": 0} + + def _completed(self, task_display_id: int) -> C2TaskStatus: + call_count["n"] += 1 + return C2TaskStatus( + display_id=task_display_id, + status="completed", + completed=True, + completed_at=datetime.now(UTC), + ) + + monkeypatch.setattr(fake_mod.FakeAdapter, "get_task", _completed) + + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + sim = _make_sim(client, admin_token, eng["id"]) + _execute(client, admin_token, sim["id"], ["whoami"]) + + _list_tasks(client, admin_token, sim["id"]) # 1st GET — marks task completed (1 call) + first_count = call_count["n"] + + _list_tasks(client, admin_token, sim["id"]) # 2nd GET — task already completed, skip poll + + # get_task should NOT have been called again on the 2nd GET. + assert call_count["n"] == first_count, "completed task should not be re-polled" + + resp = _list_tasks(client, admin_token, sim["id"]) + assert resp.status_code == 200 + task = resp.get_json()["tasks"][0] + assert task["completed"] is True + + def test_redteam_can_list_tasks( + self, client: FlaskClient, admin_token: str, redteam_token: str + ) -> None: + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + sim = _make_sim(client, admin_token, eng["id"]) + _execute(client, admin_token, sim["id"], ["whoami"]) + + resp = _list_tasks(client, redteam_token, sim["id"]) + assert resp.status_code == 200 + + +# --------------------------------------------------------------------------- +# Error cases +# --------------------------------------------------------------------------- + + +class TestListTasksErrors: + def test_404_simulation_not_found( + self, client: FlaskClient, admin_token: str + ) -> None: + resp = _list_tasks(client, admin_token, 9999) + assert resp.status_code == 404 + + def test_403_soc_forbidden( + self, client: FlaskClient, admin_token: str, soc_token: str + ) -> None: + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + sim = _make_sim(client, admin_token, eng["id"]) + + resp = _list_tasks(client, soc_token, sim["id"]) + assert resp.status_code == 403 + + def test_503_no_encryption_key( + self, monkeypatch, client: FlaskClient, admin_token: str + ) -> None: + monkeypatch.delenv("MIMIC_ENCRYPTION_KEY", raising=False) + eng = _make_engagement(client, admin_token) + sim = _make_sim(client, admin_token, eng["id"]) + + resp = _list_tasks(client, admin_token, sim["id"]) + assert resp.status_code == 503 + + def test_404_no_c2_config( + self, client: FlaskClient, admin_token: str + ) -> None: + eng = _make_engagement(client, admin_token) + sim = _make_sim(client, admin_token, eng["id"]) + + resp = _list_tasks(client, admin_token, sim["id"]) + assert resp.status_code == 404 + + def test_adapter_error_during_poll_is_tolerated( + self, monkeypatch, client: FlaskClient, admin_token: str + ) -> None: + """If get_task raises C2Error during poll, the task is skipped (best-effort).""" + from backend.app.services.c2 import fake as fake_mod + + def _boom(self, task_display_id: int): + raise C2Error("upstream unavailable") + + monkeypatch.setattr(fake_mod.FakeAdapter, "get_task", _boom) + + eng = _make_engagement(client, admin_token) + _put_config(client, admin_token, eng["id"]) + sim = _make_sim(client, admin_token, eng["id"]) + _execute(client, admin_token, sim["id"], ["whoami"]) + + # Should still return 200 with the task (un-refreshed status). + resp = _list_tasks(client, admin_token, sim["id"]) + assert resp.status_code == 200 + tasks = resp.get_json()["tasks"] + assert len(tasks) == 1 + # Status is stale (not updated due to error) — still "submitted". + assert tasks[0]["status"] == "submitted" diff --git a/backend/tests/test_migration_0007_c2.py b/backend/tests/test_migration_0007_c2.py new file mode 100644 index 0000000..11b122d --- /dev/null +++ b/backend/tests/test_migration_0007_c2.py @@ -0,0 +1,124 @@ +"""Migration round-trip test for 0007_c2_task_mapping_applied. + +Verifies that upgrade() adds the mapping_applied column and downgrade() removes it. +Uses the resolved-path pattern per lessons.md Sprint 4. +""" +from __future__ import annotations + +import importlib.util +from pathlib import Path + +from alembic.operations import Operations +from alembic.runtime.migration import MigrationContext +from sqlalchemy import create_engine, inspect, text + + +def _load_migration(name: str): + versions_dir = Path(__file__).resolve().parent.parent / "migrations" / "versions" + path = versions_dir / name + spec = importlib.util.spec_from_file_location(name.removesuffix(".py"), path) + assert spec and spec.loader + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) # type: ignore[union-attr] + return mod + + +def _fresh_engine_with_c2_task(): + """In-memory SQLite with c2_task already created (as left by 0006 upgrade).""" + engine = create_engine("sqlite:///:memory:") + with engine.begin() as conn: + conn.execute(text(""" + CREATE TABLE c2_task ( + id INTEGER PRIMARY KEY, + simulation_id INTEGER NOT NULL, + mythic_task_display_id INTEGER NOT NULL, + callback_display_id INTEGER NOT NULL, + command TEXT NOT NULL, + params TEXT, + status TEXT NOT NULL, + completed BOOLEAN NOT NULL DEFAULT 0, + output TEXT, + source TEXT NOT NULL, + created_at DATETIME NOT NULL, + completed_at DATETIME + ) + """)) + return engine + + +def _run_upgrade(engine, migration_mod): + with engine.begin() as conn: + ctx = MigrationContext.configure(conn) + ops = Operations(ctx) + ops._install_proxy() # type: ignore[attr-defined] + try: + migration_mod.upgrade() + finally: + ops._remove_proxy() # type: ignore[attr-defined] + + +def _run_downgrade(engine, migration_mod): + with engine.begin() as conn: + ctx = MigrationContext.configure(conn) + ops = Operations(ctx) + ops._install_proxy() # type: ignore[attr-defined] + try: + migration_mod.downgrade() + finally: + ops._remove_proxy() # type: ignore[attr-defined] + + +class TestMigration0007Upgrade: + def test_mapping_applied_column_added(self): + engine = _fresh_engine_with_c2_task() + mod = _load_migration("0007_c2_task_mapping_applied.py") + _run_upgrade(engine, mod) + + insp = inspect(engine) + cols = {c["name"] for c in insp.get_columns("c2_task")} + assert "mapping_applied" in cols + + def test_mapping_applied_defaults_to_false(self): + engine = _fresh_engine_with_c2_task() + mod = _load_migration("0007_c2_task_mapping_applied.py") + + # Insert a row before upgrading (no mapping_applied column yet). + with engine.begin() as conn: + conn.execute(text( + "INSERT INTO c2_task " + "(simulation_id, mythic_task_display_id, callback_display_id, " + "command, status, completed, source, created_at) " + "VALUES (1, 1000, 1, 'whoami', 'submitted', 0, 'mimic', '2026-01-01')" + )) + + _run_upgrade(engine, mod) + + with engine.begin() as conn: + row = conn.execute( + text("SELECT mapping_applied FROM c2_task WHERE id = 1") + ).fetchone() + assert row is not None + # SQLite stores booleans as 0/1. + assert row[0] == 0 or row[0] is False + + +class TestMigration0007Downgrade: + def test_downgrade_removes_mapping_applied(self): + engine = _fresh_engine_with_c2_task() + mod = _load_migration("0007_c2_task_mapping_applied.py") + _run_upgrade(engine, mod) + _run_downgrade(engine, mod) + + insp = inspect(engine) + cols = {c["name"] for c in insp.get_columns("c2_task")} + assert "mapping_applied" not in cols + + def test_downgrade_does_not_drop_other_columns(self): + engine = _fresh_engine_with_c2_task() + mod = _load_migration("0007_c2_task_mapping_applied.py") + _run_upgrade(engine, mod) + _run_downgrade(engine, mod) + + insp = inspect(engine) + cols = {c["name"] for c in insp.get_columns("c2_task")} + assert {"id", "simulation_id", "command", "status", "completed"} <= cols