feat(c2): integrate Mythic command and control (sprint 8) #11
@@ -21,6 +21,7 @@ from backend.app.models.c2_task import C2Task, C2TaskSource
|
||||
from backend.app.models.simulation import Simulation, SimulationStatus
|
||||
from backend.app.services.c2.adapter import C2Error
|
||||
from backend.app.services.c2.factory import get_adapter
|
||||
from backend.app.services.c2.mapping import apply_task_to_simulation
|
||||
from backend.app.services.crypto import C2Disabled, decrypt, encrypt
|
||||
from backend.app.services.simulation_workflow import promote_to_in_progress
|
||||
|
||||
@@ -297,3 +298,72 @@ def execute_simulation(sid: int):
|
||||
for t in created_tasks
|
||||
]
|
||||
}), 200
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# M3 — poll-on-read task listing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@sims_c2_bp.get("/<int:sid>/c2/tasks")
|
||||
@role_required("admin", "redteam")
|
||||
def list_simulation_tasks(sid: int):
|
||||
guard = _crypto_guard()
|
||||
if guard is not None:
|
||||
return guard
|
||||
|
||||
sim = db.session.get(Simulation, sid)
|
||||
if sim is None:
|
||||
return jsonify({"error": "Simulation not found"}), 404
|
||||
|
||||
engagement = db.session.get(Engagement, sim.engagement_id)
|
||||
if engagement is None:
|
||||
return jsonify({"error": "Engagement not found"}), 404
|
||||
|
||||
adapter, err = _load_adapter_for_engagement(engagement)
|
||||
if err is not None:
|
||||
return err
|
||||
|
||||
tasks: list[C2Task] = C2Task.query.filter_by(simulation_id=sid).all()
|
||||
|
||||
for task in tasks:
|
||||
if task.completed:
|
||||
continue
|
||||
|
||||
try:
|
||||
status = adapter.get_task(task.mythic_task_display_id)
|
||||
except C2Error:
|
||||
# Best-effort refresh — skip this task if the adapter fails.
|
||||
continue
|
||||
|
||||
task.status = status.status
|
||||
task.completed = status.completed
|
||||
|
||||
if status.completed:
|
||||
task.completed_at = status.completed_at or datetime.now(UTC)
|
||||
try:
|
||||
task.output = adapter.get_task_output(task.mythic_task_display_id)
|
||||
except C2Error:
|
||||
task.output = ""
|
||||
apply_task_to_simulation(task, sim)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({
|
||||
"tasks": [
|
||||
{
|
||||
"id": t.id,
|
||||
"mythic_task_display_id": t.mythic_task_display_id,
|
||||
"callback_display_id": t.callback_display_id,
|
||||
"command": t.command,
|
||||
"params": t.params,
|
||||
"status": t.status,
|
||||
"completed": t.completed,
|
||||
"output": t.output,
|
||||
"mapping_applied": t.mapping_applied,
|
||||
"created_at": t.created_at.isoformat() if t.created_at else None,
|
||||
"completed_at": t.completed_at.isoformat() if t.completed_at else None,
|
||||
}
|
||||
for t in tasks
|
||||
]
|
||||
}), 200
|
||||
|
||||
@@ -37,6 +37,7 @@ class C2Task(db.Model): # type: ignore[name-defined]
|
||||
db.DateTime, nullable=False, default=lambda: datetime.now(UTC)
|
||||
)
|
||||
completed_at = db.Column(db.DateTime, nullable=True)
|
||||
mapping_applied = db.Column(db.Boolean, nullable=False, default=False)
|
||||
|
||||
simulation = db.relationship(
|
||||
"Simulation",
|
||||
|
||||
@@ -4,7 +4,8 @@ from __future__ import annotations
|
||||
import base64
|
||||
import binascii
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class C2Error(Exception):
|
||||
@@ -32,6 +33,7 @@ class C2TaskStatus:
|
||||
display_id: int
|
||||
status: str
|
||||
completed: bool
|
||||
completed_at: datetime | None = field(default=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -8,6 +8,7 @@ from __future__ import annotations
|
||||
from backend.app.services.c2.adapter import (
|
||||
C2Adapter,
|
||||
C2Callback,
|
||||
C2Error,
|
||||
C2Health,
|
||||
C2TaskPage,
|
||||
C2TaskStatus,
|
||||
@@ -46,11 +47,17 @@ class FakeAdapter(C2Adapter):
|
||||
"""In-memory adapter with deterministic behaviour.
|
||||
|
||||
Each instance starts with an empty task store and display_ids from 1000.
|
||||
|
||||
get_task() state progression per task (keyed by display_id):
|
||||
- First call after create_task → submitted, completed=False
|
||||
- Second and subsequent calls → completed=True, status="completed"
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._tasks: dict[int, dict] = {}
|
||||
self._next_task_id = 1000
|
||||
# Tracks how many times get_task has been called per display_id.
|
||||
self._get_task_calls: dict[int, int] = {}
|
||||
|
||||
def test_connection(self) -> C2Health:
|
||||
return C2Health(ok=True)
|
||||
@@ -78,20 +85,44 @@ class FakeAdapter(C2Adapter):
|
||||
return tid
|
||||
|
||||
def get_task(self, task_display_id: int) -> C2TaskStatus:
|
||||
"""Deterministic state progression: first call → submitted, second+ → completed.
|
||||
|
||||
Tracks call count regardless of whether the task was created by this instance,
|
||||
so the endpoint poll-on-read flow works across separate adapter instantiations.
|
||||
"""
|
||||
call_count = self._get_task_calls.get(task_display_id, 0) + 1
|
||||
self._get_task_calls[task_display_id] = call_count
|
||||
|
||||
task = self._tasks.get(task_display_id)
|
||||
if task is None:
|
||||
return C2TaskStatus(display_id=task_display_id, status="unknown", completed=False)
|
||||
|
||||
if call_count >= 2:
|
||||
completed = True
|
||||
status = "completed"
|
||||
if task is not None:
|
||||
task["status"] = "completed"
|
||||
task["completed"] = True
|
||||
else:
|
||||
completed = False
|
||||
status = task["status"] if task is not None else "submitted"
|
||||
|
||||
return C2TaskStatus(
|
||||
display_id=task_display_id,
|
||||
status=task["status"],
|
||||
completed=task["completed"],
|
||||
status=status,
|
||||
completed=completed,
|
||||
)
|
||||
|
||||
def get_task_output(self, task_display_id: int) -> str:
|
||||
"""Returns deterministic output once task is completed; raises C2Error before that."""
|
||||
# Check call count — completed if get_task was called at least twice.
|
||||
if self._get_task_calls.get(task_display_id, 0) < 2:
|
||||
# Also allow tasks in _tasks that were explicitly set to completed.
|
||||
task = self._tasks.get(task_display_id)
|
||||
if task is None:
|
||||
return ""
|
||||
return task.get("output") or ""
|
||||
if task is None or not task.get("completed", False):
|
||||
raise C2Error("task not completed")
|
||||
|
||||
task = self._tasks.get(task_display_id)
|
||||
command = task["command"] if task is not None else "unknown"
|
||||
return f"output for task {task_display_id}: {command}\n"
|
||||
|
||||
def list_callback_tasks(
|
||||
self,
|
||||
|
||||
38
backend/app/services/c2/mapping.py
Normal file
38
backend/app/services/c2/mapping.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""C2 task → Simulation output mapping.
|
||||
|
||||
apply_task_to_simulation() writes task output into the simulation's
|
||||
execution_result field and marks the task as mapping_applied=True so that
|
||||
the operation is idempotent (safe to call multiple times for the same task).
|
||||
|
||||
Caller is responsible for committing the session.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from backend.app.models.c2_task import C2Task
|
||||
from backend.app.models.simulation import Simulation
|
||||
|
||||
|
||||
def apply_task_to_simulation(task: C2Task, simulation: Simulation) -> None:
|
||||
"""Write task output into simulation.execution_result (append, newline-separated).
|
||||
|
||||
No-op if task.mapping_applied is already True or task.output is empty.
|
||||
Marks task.mapping_applied = True on completion.
|
||||
"""
|
||||
if task.mapping_applied:
|
||||
return
|
||||
|
||||
output = (task.output or "").strip()
|
||||
if not output:
|
||||
task.mapping_applied = True
|
||||
return
|
||||
|
||||
existing = (simulation.execution_result or "").rstrip("\n")
|
||||
if existing:
|
||||
simulation.execution_result = existing + "\n" + output
|
||||
else:
|
||||
simulation.execution_result = output
|
||||
|
||||
simulation.updated_at = datetime.now(UTC)
|
||||
task.mapping_applied = True
|
||||
@@ -12,6 +12,8 @@ M4: list_callback_tasks()
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import requests
|
||||
|
||||
from backend.app.services.c2.adapter import (
|
||||
@@ -21,6 +23,7 @@ from backend.app.services.c2.adapter import (
|
||||
C2Health,
|
||||
C2TaskPage,
|
||||
C2TaskStatus,
|
||||
decode_response_text,
|
||||
)
|
||||
|
||||
_HEALTH_QUERY = "{ __typename }"
|
||||
@@ -54,6 +57,28 @@ mutation CreateTask($callback_id: Int!, $command: String!, $params: String!) {
|
||||
}
|
||||
"""
|
||||
|
||||
_GET_TASK_QUERY = """
|
||||
query GetTask($display_id: Int!) {
|
||||
task(where: {display_id: {_eq: $display_id}}) {
|
||||
display_id
|
||||
status
|
||||
completed
|
||||
timestamp
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
_GET_TASK_OUTPUT_QUERY = """
|
||||
query GetTaskOutput($display_id: Int!) {
|
||||
response(
|
||||
where: {task: {display_id: {_eq: $display_id}}}
|
||||
order_by: {id: asc}
|
||||
) {
|
||||
response_text
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
class MythicAdapter(C2Adapter):
|
||||
"""Real Mythic 3.x adapter using GraphQL over HTTP."""
|
||||
@@ -144,10 +169,52 @@ class MythicAdapter(C2Adapter):
|
||||
return int(task_data["display_id"])
|
||||
|
||||
def get_task(self, task_display_id: int) -> C2TaskStatus:
|
||||
raise NotImplementedError("M3")
|
||||
"""Return current task status from Mythic."""
|
||||
try:
|
||||
data = self._post({
|
||||
"query": _GET_TASK_QUERY,
|
||||
"variables": {"display_id": task_display_id},
|
||||
})
|
||||
except requests.RequestException as exc:
|
||||
raise C2Error(str(exc)) from exc
|
||||
|
||||
rows = data.get("data", {}).get("task", [])
|
||||
if not rows:
|
||||
raise C2Error(f"task {task_display_id} not found in Mythic")
|
||||
row = rows[0]
|
||||
|
||||
completed_at: datetime | None = None
|
||||
if row.get("completed") and row.get("timestamp"):
|
||||
try:
|
||||
completed_at = datetime.fromisoformat(
|
||||
row["timestamp"].replace("Z", "+00:00")
|
||||
)
|
||||
except ValueError:
|
||||
completed_at = None
|
||||
|
||||
return C2TaskStatus(
|
||||
display_id=row["display_id"],
|
||||
status=row["status"],
|
||||
completed=bool(row.get("completed", False)),
|
||||
completed_at=completed_at,
|
||||
)
|
||||
|
||||
def get_task_output(self, task_display_id: int) -> str:
|
||||
raise NotImplementedError("M3")
|
||||
"""Return decoded, concatenated output for a task."""
|
||||
try:
|
||||
data = self._post({
|
||||
"query": _GET_TASK_OUTPUT_QUERY,
|
||||
"variables": {"display_id": task_display_id},
|
||||
})
|
||||
except requests.RequestException as exc:
|
||||
raise C2Error(str(exc)) from exc
|
||||
|
||||
rows = data.get("data", {}).get("response", [])
|
||||
return "".join(
|
||||
decode_response_text(r["response_text"])
|
||||
for r in rows
|
||||
if r.get("response_text")
|
||||
)
|
||||
|
||||
def list_callback_tasks(
|
||||
self,
|
||||
|
||||
30
backend/migrations/versions/0007_c2_task_mapping_applied.py
Normal file
30
backend/migrations/versions/0007_c2_task_mapping_applied.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""add mapping_applied column to c2_task
|
||||
|
||||
Revision ID: 0007
|
||||
Revises: 0006
|
||||
Create Date: 2026-06-10 00:00:00.000000
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = "0007"
|
||||
down_revision = "0006"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
with op.batch_alter_table("c2_task") as batch_op:
|
||||
batch_op.add_column(
|
||||
sa.Column(
|
||||
"mapping_applied",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
with op.batch_alter_table("c2_task") as batch_op:
|
||||
batch_op.drop_column("mapping_applied")
|
||||
110
backend/tests/test_c2_adapter_fake_m3.py
Normal file
110
backend/tests/test_c2_adapter_fake_m3.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""FakeAdapter M3 state-progression tests — get_task and get_task_output."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.app.services.c2.adapter import C2Error
|
||||
from backend.app.services.c2.fake import FakeAdapter
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def adapter() -> FakeAdapter:
|
||||
return FakeAdapter()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def adapter_with_task(adapter: FakeAdapter) -> tuple[FakeAdapter, int]:
|
||||
tid = adapter.create_task(callback_display_id=1, command="whoami")
|
||||
return adapter, tid
|
||||
|
||||
|
||||
class TestFakeAdapterGetTaskProgression:
|
||||
def test_first_call_returns_submitted(self, adapter_with_task):
|
||||
a, tid = adapter_with_task
|
||||
status = a.get_task(tid)
|
||||
assert status.status == "submitted"
|
||||
assert status.completed is False
|
||||
|
||||
def test_second_call_returns_completed(self, adapter_with_task):
|
||||
a, tid = adapter_with_task
|
||||
a.get_task(tid) # first call
|
||||
status = a.get_task(tid) # second call
|
||||
assert status.status == "completed"
|
||||
assert status.completed is True
|
||||
|
||||
def test_subsequent_calls_stay_completed(self, adapter_with_task):
|
||||
a, tid = adapter_with_task
|
||||
for _ in range(5):
|
||||
a.get_task(tid)
|
||||
status = a.get_task(tid)
|
||||
assert status.completed is True
|
||||
|
||||
def test_unknown_task_id_returns_submitted_on_first_call(self, adapter):
|
||||
"""A task ID not created by this instance still goes through submitted→completed."""
|
||||
status = adapter.get_task(9999)
|
||||
assert status.display_id == 9999
|
||||
assert status.status == "submitted"
|
||||
assert status.completed is False
|
||||
|
||||
def test_call_counters_are_per_task(self, adapter):
|
||||
"""Two tasks have independent state — completing one does not affect the other."""
|
||||
t1 = adapter.create_task(callback_display_id=1, command="whoami")
|
||||
t2 = adapter.create_task(callback_display_id=1, command="ipconfig")
|
||||
|
||||
# Advance t1 to completed via two calls.
|
||||
adapter.get_task(t1)
|
||||
adapter.get_task(t1)
|
||||
|
||||
# t2 first call should still be submitted.
|
||||
s2 = adapter.get_task(t2)
|
||||
assert s2.status == "submitted"
|
||||
assert s2.completed is False
|
||||
|
||||
def test_instances_are_isolated(self):
|
||||
"""Per-instance counters — different FakeAdapter instances don't share state."""
|
||||
a1 = FakeAdapter()
|
||||
a2 = FakeAdapter()
|
||||
|
||||
t1 = a1.create_task(1, "cmd")
|
||||
t2 = a2.create_task(1, "cmd")
|
||||
|
||||
a1.get_task(t1)
|
||||
a1.get_task(t1) # a1's task is now completed
|
||||
|
||||
# a2's task with same display_id (both start at 1000) should be independent.
|
||||
assert t1 == t2 == 1000
|
||||
s2 = a2.get_task(t2)
|
||||
assert s2.status == "submitted"
|
||||
|
||||
|
||||
class TestFakeAdapterGetTaskOutput:
|
||||
def test_raises_before_completed(self, adapter_with_task):
|
||||
a, tid = adapter_with_task
|
||||
with pytest.raises(C2Error, match="task not completed"):
|
||||
a.get_task_output(tid)
|
||||
|
||||
def test_raises_after_first_get_task_call_only(self, adapter_with_task):
|
||||
a, tid = adapter_with_task
|
||||
a.get_task(tid) # first call — still submitted
|
||||
with pytest.raises(C2Error, match="task not completed"):
|
||||
a.get_task_output(tid)
|
||||
|
||||
def test_returns_output_after_completed(self, adapter_with_task):
|
||||
a, tid = adapter_with_task
|
||||
a.get_task(tid)
|
||||
a.get_task(tid) # now completed
|
||||
output = a.get_task_output(tid)
|
||||
assert "whoami" in output
|
||||
assert str(tid) in output
|
||||
|
||||
def test_output_format(self, adapter):
|
||||
tid = adapter.create_task(callback_display_id=2, command="ipconfig /all")
|
||||
adapter.get_task(tid)
|
||||
adapter.get_task(tid)
|
||||
output = adapter.get_task_output(tid)
|
||||
assert output == f"output for task {tid}: ipconfig /all\n"
|
||||
|
||||
def test_unknown_task_raises_c2error(self, adapter):
|
||||
"""Task ID never created and never polled — not completed → C2Error."""
|
||||
with pytest.raises(C2Error, match="task not completed"):
|
||||
adapter.get_task_output(9999)
|
||||
188
backend/tests/test_c2_adapter_mythic_m3.py
Normal file
188
backend/tests/test_c2_adapter_mythic_m3.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""MythicAdapter M3 tests — get_task and get_task_output, mocked HTTP."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
import requests_mock as rm_module
|
||||
|
||||
from backend.app.services.c2.adapter import C2Error
|
||||
from backend.app.services.c2.mythic import MythicAdapter
|
||||
|
||||
_BASE_URL = "https://mythic.lab:7443"
|
||||
_GQL_URL = _BASE_URL + "/graphql"
|
||||
_TOKEN = "fake-api-token"
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def adapter():
|
||||
return MythicAdapter(url=_BASE_URL, api_token=_TOKEN, verify_tls=False)
|
||||
|
||||
|
||||
class TestMythicAdapterGetTask:
|
||||
def test_returns_status_for_incomplete_task(self, adapter):
|
||||
payload = {
|
||||
"data": {
|
||||
"task": [
|
||||
{
|
||||
"display_id": 7,
|
||||
"status": "processing",
|
||||
"completed": False,
|
||||
"timestamp": None,
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
with rm_module.Mocker() as m:
|
||||
m.post(_GQL_URL, json=payload)
|
||||
status = adapter.get_task(7)
|
||||
|
||||
assert status.display_id == 7
|
||||
assert status.status == "processing"
|
||||
assert status.completed is False
|
||||
assert status.completed_at is None
|
||||
|
||||
def test_returns_completed_at_for_completed_task(self, adapter):
|
||||
payload = {
|
||||
"data": {
|
||||
"task": [
|
||||
{
|
||||
"display_id": 7,
|
||||
"status": "completed",
|
||||
"completed": True,
|
||||
"timestamp": "2026-06-10T12:00:00Z",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
with rm_module.Mocker() as m:
|
||||
m.post(_GQL_URL, json=payload)
|
||||
status = adapter.get_task(7)
|
||||
|
||||
assert status.completed is True
|
||||
assert status.completed_at is not None
|
||||
assert status.completed_at.year == 2026
|
||||
|
||||
def test_raises_when_task_not_found(self, adapter):
|
||||
payload = {"data": {"task": []}}
|
||||
with rm_module.Mocker() as m:
|
||||
m.post(_GQL_URL, json=payload)
|
||||
with pytest.raises(C2Error, match="not found"):
|
||||
adapter.get_task(999)
|
||||
|
||||
def test_sends_apitoken_header(self, adapter):
|
||||
payload = {
|
||||
"data": {
|
||||
"task": [
|
||||
{"display_id": 1, "status": "submitted", "completed": False, "timestamp": None}
|
||||
]
|
||||
}
|
||||
}
|
||||
with rm_module.Mocker() as m:
|
||||
m.post(_GQL_URL, json=payload)
|
||||
adapter.get_task(1)
|
||||
assert m.last_request.headers.get("apitoken") == _TOKEN
|
||||
|
||||
def test_network_error_raises_c2error(self, adapter):
|
||||
with rm_module.Mocker() as m:
|
||||
m.post(_GQL_URL, exc=requests.exceptions.ConnectionError("refused"))
|
||||
with pytest.raises(C2Error):
|
||||
adapter.get_task(1)
|
||||
|
||||
def test_no_redirect_followed(self, adapter):
|
||||
"""get_task must not follow HTTP redirects."""
|
||||
with rm_module.Mocker() as m:
|
||||
m.post(_GQL_URL, status_code=301, headers={"Location": "https://evil.example/"})
|
||||
with pytest.raises(C2Error):
|
||||
adapter.get_task(1)
|
||||
assert len(m.request_history) == 1
|
||||
|
||||
def test_invalid_timestamp_does_not_crash(self, adapter):
|
||||
"""A malformed timestamp field falls back to completed_at=None without raising."""
|
||||
payload = {
|
||||
"data": {
|
||||
"task": [
|
||||
{
|
||||
"display_id": 5,
|
||||
"status": "completed",
|
||||
"completed": True,
|
||||
"timestamp": "not-a-date",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
with rm_module.Mocker() as m:
|
||||
m.post(_GQL_URL, json=payload)
|
||||
status = adapter.get_task(5)
|
||||
|
||||
assert status.completed is True
|
||||
assert status.completed_at is None
|
||||
|
||||
|
||||
class TestMythicAdapterGetTaskOutput:
|
||||
def test_returns_decoded_output(self, adapter):
|
||||
import base64
|
||||
encoded = base64.b64encode(b"Administrator\r\n").decode()
|
||||
payload = {
|
||||
"data": {
|
||||
"response": [{"response_text": encoded}]
|
||||
}
|
||||
}
|
||||
with rm_module.Mocker() as m:
|
||||
m.post(_GQL_URL, json=payload)
|
||||
output = adapter.get_task_output(7)
|
||||
|
||||
assert "Administrator" in output
|
||||
|
||||
def test_concatenates_multiple_responses(self, adapter):
|
||||
import base64
|
||||
r1 = base64.b64encode(b"line one\n").decode()
|
||||
r2 = base64.b64encode(b"line two\n").decode()
|
||||
payload = {
|
||||
"data": {
|
||||
"response": [{"response_text": r1}, {"response_text": r2}]
|
||||
}
|
||||
}
|
||||
with rm_module.Mocker() as m:
|
||||
m.post(_GQL_URL, json=payload)
|
||||
output = adapter.get_task_output(7)
|
||||
|
||||
assert "line one" in output
|
||||
assert "line two" in output
|
||||
|
||||
def test_returns_empty_string_when_no_responses(self, adapter):
|
||||
payload = {"data": {"response": []}}
|
||||
with rm_module.Mocker() as m:
|
||||
m.post(_GQL_URL, json=payload)
|
||||
output = adapter.get_task_output(7)
|
||||
|
||||
assert output == ""
|
||||
|
||||
def test_skips_empty_response_text(self, adapter):
|
||||
import base64
|
||||
encoded = base64.b64encode(b"real output").decode()
|
||||
payload = {
|
||||
"data": {
|
||||
"response": [
|
||||
{"response_text": ""},
|
||||
{"response_text": encoded},
|
||||
]
|
||||
}
|
||||
}
|
||||
with rm_module.Mocker() as m:
|
||||
m.post(_GQL_URL, json=payload)
|
||||
output = adapter.get_task_output(7)
|
||||
|
||||
assert output == "real output"
|
||||
|
||||
def test_network_error_raises_c2error(self, adapter):
|
||||
with rm_module.Mocker() as m:
|
||||
m.post(_GQL_URL, exc=requests.exceptions.Timeout("timeout"))
|
||||
with pytest.raises(C2Error):
|
||||
adapter.get_task_output(7)
|
||||
|
||||
def test_no_redirect_followed(self, adapter):
|
||||
with rm_module.Mocker() as m:
|
||||
m.post(_GQL_URL, status_code=302, headers={"Location": "https://evil.example/"})
|
||||
with pytest.raises(C2Error):
|
||||
adapter.get_task_output(1)
|
||||
assert len(m.request_history) == 1
|
||||
97
backend/tests/test_c2_mapping.py
Normal file
97
backend/tests/test_c2_mapping.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""Unit tests for apply_task_to_simulation() mapping helper."""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.app.services.c2.mapping import apply_task_to_simulation
|
||||
|
||||
|
||||
def _make_task(output: str | None = "whoami output", mapping_applied: bool = False) -> MagicMock:
|
||||
task = MagicMock()
|
||||
task.output = output
|
||||
task.mapping_applied = mapping_applied
|
||||
return task
|
||||
|
||||
|
||||
def _make_sim(execution_result: str | None = None) -> MagicMock:
|
||||
sim = MagicMock()
|
||||
sim.execution_result = execution_result
|
||||
sim.updated_at = None
|
||||
return sim
|
||||
|
||||
|
||||
class TestApplyTaskToSimulation:
|
||||
def test_appends_output_to_empty_simulation(self):
|
||||
task = _make_task(output="whoami output")
|
||||
sim = _make_sim(execution_result=None)
|
||||
|
||||
apply_task_to_simulation(task, sim)
|
||||
|
||||
assert sim.execution_result == "whoami output"
|
||||
assert task.mapping_applied is True
|
||||
|
||||
def test_appends_with_newline_separator(self):
|
||||
task = _make_task(output="second result")
|
||||
sim = _make_sim(execution_result="first result")
|
||||
|
||||
apply_task_to_simulation(task, sim)
|
||||
|
||||
assert sim.execution_result == "first result\nsecond result"
|
||||
|
||||
def test_idempotent_when_already_applied(self):
|
||||
task = _make_task(output="some output", mapping_applied=True)
|
||||
sim = _make_sim(execution_result="existing")
|
||||
|
||||
apply_task_to_simulation(task, sim)
|
||||
|
||||
# execution_result must not be modified.
|
||||
assert sim.execution_result == "existing"
|
||||
|
||||
def test_no_op_when_output_is_empty_string(self):
|
||||
task = _make_task(output="")
|
||||
sim = _make_sim(execution_result="existing")
|
||||
|
||||
apply_task_to_simulation(task, sim)
|
||||
|
||||
assert sim.execution_result == "existing"
|
||||
# Still marks mapping_applied so we don't revisit it.
|
||||
assert task.mapping_applied is True
|
||||
|
||||
def test_no_op_when_output_is_none(self):
|
||||
task = _make_task(output=None)
|
||||
sim = _make_sim(execution_result="existing")
|
||||
|
||||
apply_task_to_simulation(task, sim)
|
||||
|
||||
assert sim.execution_result == "existing"
|
||||
assert task.mapping_applied is True
|
||||
|
||||
def test_strips_trailing_newlines_from_existing(self):
|
||||
"""Existing execution_result with trailing newlines should not cause double blank lines."""
|
||||
task = _make_task(output="new output")
|
||||
sim = _make_sim(execution_result="old output\n\n")
|
||||
|
||||
apply_task_to_simulation(task, sim)
|
||||
|
||||
assert sim.execution_result == "old output\nnew output"
|
||||
|
||||
def test_updated_at_is_set_on_sim(self):
|
||||
task = _make_task(output="something")
|
||||
sim = _make_sim(execution_result=None)
|
||||
before = datetime.now(UTC)
|
||||
|
||||
apply_task_to_simulation(task, sim)
|
||||
|
||||
assert sim.updated_at is not None
|
||||
assert sim.updated_at >= before
|
||||
|
||||
def test_multiple_tasks_accumulate(self):
|
||||
sim = _make_sim(execution_result=None)
|
||||
tasks = [_make_task(output=f"result {i}") for i in range(3)]
|
||||
|
||||
for t in tasks:
|
||||
apply_task_to_simulation(t, sim)
|
||||
|
||||
lines = sim.execution_result.split("\n")
|
||||
assert lines == ["result 0", "result 1", "result 2"]
|
||||
374
backend/tests/test_c2_tasks_list.py
Normal file
374
backend/tests/test_c2_tasks_list.py
Normal file
@@ -0,0 +1,374 @@
|
||||
"""Tests for GET /api/simulations/<id>/c2/tasks — poll-on-read endpoint."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from cryptography.fernet import Fernet
|
||||
from flask import Flask
|
||||
from flask.testing import FlaskClient
|
||||
|
||||
from backend.app.extensions import db
|
||||
from backend.app.models.c2_task import C2Task
|
||||
from backend.app.models.simulation import Simulation
|
||||
from backend.app.services.c2.adapter import C2Error, C2TaskStatus
|
||||
from backend.tests.conftest import auth_headers as _h
|
||||
|
||||
_FERNET_KEY = Fernet.generate_key().decode()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def set_encryption_key(monkeypatch):
|
||||
monkeypatch.setenv("MIMIC_ENCRYPTION_KEY", _FERNET_KEY)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def use_fake_adapter(monkeypatch):
|
||||
monkeypatch.setenv("MIMIC_C2_ADAPTER", "fake")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_engagement(client: FlaskClient, token: str) -> dict:
|
||||
resp = client.post(
|
||||
"/api/engagements",
|
||||
headers=_h(token),
|
||||
json={"name": "Op Alpha", "start_date": "2026-06-10"},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
return resp.get_json()
|
||||
|
||||
|
||||
def _put_config(client: FlaskClient, token: str, eid: int) -> None:
|
||||
resp = client.put(
|
||||
f"/api/engagements/{eid}/c2-config",
|
||||
headers=_h(token),
|
||||
json={"url": "https://c2.internal:7443", "api_token": "s3cr3t", "verify_tls": True},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
def _make_sim(client: FlaskClient, token: str, eid: int) -> dict:
|
||||
resp = client.post(
|
||||
f"/api/engagements/{eid}/simulations",
|
||||
headers=_h(token),
|
||||
json={"name": "Sim Alpha"},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
return resp.get_json()
|
||||
|
||||
|
||||
def _execute(client: FlaskClient, token: str, sid: int, commands: list, callback_display_id: int = 1):
|
||||
return client.post(
|
||||
f"/api/simulations/{sid}/c2/execute",
|
||||
headers=_h(token),
|
||||
json={"callback_display_id": callback_display_id, "commands": commands},
|
||||
)
|
||||
|
||||
|
||||
def _list_tasks(client: FlaskClient, token: str, sid: int):
|
||||
return client.get(
|
||||
f"/api/simulations/{sid}/c2/tasks",
|
||||
headers=_h(token),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Happy path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestListTasksHappyPath:
|
||||
def test_returns_empty_list_when_no_tasks(
|
||||
self, client: FlaskClient, admin_token: str
|
||||
) -> None:
|
||||
eng = _make_engagement(client, admin_token)
|
||||
_put_config(client, admin_token, eng["id"])
|
||||
sim = _make_sim(client, admin_token, eng["id"])
|
||||
|
||||
resp = _list_tasks(client, admin_token, sim["id"])
|
||||
assert resp.status_code == 200
|
||||
assert resp.get_json()["tasks"] == []
|
||||
|
||||
def test_returns_task_after_execute(
|
||||
self, client: FlaskClient, admin_token: str
|
||||
) -> None:
|
||||
eng = _make_engagement(client, admin_token)
|
||||
_put_config(client, admin_token, eng["id"])
|
||||
sim = _make_sim(client, admin_token, eng["id"])
|
||||
_execute(client, admin_token, sim["id"], ["whoami"])
|
||||
|
||||
resp = _list_tasks(client, admin_token, sim["id"])
|
||||
assert resp.status_code == 200
|
||||
tasks = resp.get_json()["tasks"]
|
||||
assert len(tasks) == 1
|
||||
assert tasks[0]["command"] == "whoami"
|
||||
|
||||
def test_task_shape(
|
||||
self, client: FlaskClient, admin_token: str
|
||||
) -> None:
|
||||
eng = _make_engagement(client, admin_token)
|
||||
_put_config(client, admin_token, eng["id"])
|
||||
sim = _make_sim(client, admin_token, eng["id"])
|
||||
_execute(client, admin_token, sim["id"], ["hostname"])
|
||||
|
||||
resp = _list_tasks(client, admin_token, sim["id"])
|
||||
task = resp.get_json()["tasks"][0]
|
||||
for field in ("id", "mythic_task_display_id", "callback_display_id",
|
||||
"command", "params", "status", "completed", "output",
|
||||
"mapping_applied", "created_at", "completed_at"):
|
||||
assert field in task, f"missing field: {field}"
|
||||
|
||||
def test_first_poll_returns_submitted(
|
||||
self, client: FlaskClient, admin_token: str
|
||||
) -> None:
|
||||
eng = _make_engagement(client, admin_token)
|
||||
_put_config(client, admin_token, eng["id"])
|
||||
sim = _make_sim(client, admin_token, eng["id"])
|
||||
_execute(client, admin_token, sim["id"], ["whoami"])
|
||||
|
||||
# First GET — FakeAdapter.get_task() first call → submitted.
|
||||
resp = _list_tasks(client, admin_token, sim["id"])
|
||||
task = resp.get_json()["tasks"][0]
|
||||
assert task["status"] == "submitted"
|
||||
assert task["completed"] is False
|
||||
|
||||
def test_poll_marks_completed_when_adapter_returns_completed(
|
||||
self, app: Flask, monkeypatch, client: FlaskClient, admin_token: str
|
||||
) -> None:
|
||||
"""When adapter.get_task returns completed=True the task is updated in DB."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from backend.app.services.c2 import fake as fake_mod
|
||||
|
||||
def _completed(self, task_display_id: int) -> C2TaskStatus:
|
||||
return C2TaskStatus(
|
||||
display_id=task_display_id,
|
||||
status="completed",
|
||||
completed=True,
|
||||
completed_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(fake_mod.FakeAdapter, "get_task", _completed)
|
||||
|
||||
eng = _make_engagement(client, admin_token)
|
||||
_put_config(client, admin_token, eng["id"])
|
||||
sim = _make_sim(client, admin_token, eng["id"])
|
||||
_execute(client, admin_token, sim["id"], ["whoami"])
|
||||
|
||||
resp = _list_tasks(client, admin_token, sim["id"])
|
||||
task = resp.get_json()["tasks"][0]
|
||||
assert task["completed"] is True
|
||||
assert task["status"] == "completed"
|
||||
|
||||
def test_output_populated_after_completion(
|
||||
self, monkeypatch, client: FlaskClient, admin_token: str
|
||||
) -> None:
|
||||
"""Output is fetched and stored when task transitions to completed."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from backend.app.services.c2 import fake as fake_mod
|
||||
|
||||
def _completed(self, task_display_id: int) -> C2TaskStatus:
|
||||
return C2TaskStatus(
|
||||
display_id=task_display_id,
|
||||
status="completed",
|
||||
completed=True,
|
||||
completed_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
def _output(self, task_display_id: int) -> str:
|
||||
return f"whoami result for task {task_display_id}"
|
||||
|
||||
monkeypatch.setattr(fake_mod.FakeAdapter, "get_task", _completed)
|
||||
monkeypatch.setattr(fake_mod.FakeAdapter, "get_task_output", _output)
|
||||
|
||||
eng = _make_engagement(client, admin_token)
|
||||
_put_config(client, admin_token, eng["id"])
|
||||
sim = _make_sim(client, admin_token, eng["id"])
|
||||
_execute(client, admin_token, sim["id"], ["whoami"])
|
||||
|
||||
resp = _list_tasks(client, admin_token, sim["id"])
|
||||
task = resp.get_json()["tasks"][0]
|
||||
assert task["output"] is not None
|
||||
assert "whoami" in task["output"]
|
||||
|
||||
def test_mapping_applied_set_after_completion(
|
||||
self, app: Flask, monkeypatch, client: FlaskClient, admin_token: str
|
||||
) -> None:
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from backend.app.services.c2 import fake as fake_mod
|
||||
|
||||
def _completed(self, task_display_id: int) -> C2TaskStatus:
|
||||
return C2TaskStatus(
|
||||
display_id=task_display_id,
|
||||
status="completed",
|
||||
completed=True,
|
||||
completed_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(fake_mod.FakeAdapter, "get_task", _completed)
|
||||
|
||||
eng = _make_engagement(client, admin_token)
|
||||
_put_config(client, admin_token, eng["id"])
|
||||
sim = _make_sim(client, admin_token, eng["id"])
|
||||
_execute(client, admin_token, sim["id"], ["whoami"])
|
||||
|
||||
_list_tasks(client, admin_token, sim["id"])
|
||||
|
||||
with app.app_context():
|
||||
task = C2Task.query.filter_by(simulation_id=sim["id"]).first()
|
||||
assert task is not None
|
||||
assert task.mapping_applied is True
|
||||
|
||||
def test_execution_result_updated_on_simulation(
|
||||
self, app: Flask, monkeypatch, client: FlaskClient, admin_token: str
|
||||
) -> None:
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from backend.app.services.c2 import fake as fake_mod
|
||||
|
||||
def _completed(self, task_display_id: int) -> C2TaskStatus:
|
||||
return C2TaskStatus(
|
||||
display_id=task_display_id,
|
||||
status="completed",
|
||||
completed=True,
|
||||
completed_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
def _output(self, task_display_id: int) -> str:
|
||||
return f"WORKSTATION-01\\whoami output {task_display_id}"
|
||||
|
||||
monkeypatch.setattr(fake_mod.FakeAdapter, "get_task", _completed)
|
||||
monkeypatch.setattr(fake_mod.FakeAdapter, "get_task_output", _output)
|
||||
|
||||
eng = _make_engagement(client, admin_token)
|
||||
_put_config(client, admin_token, eng["id"])
|
||||
sim = _make_sim(client, admin_token, eng["id"])
|
||||
_execute(client, admin_token, sim["id"], ["whoami"])
|
||||
|
||||
_list_tasks(client, admin_token, sim["id"])
|
||||
|
||||
with app.app_context():
|
||||
updated_sim = db.session.get(Simulation, sim["id"])
|
||||
assert updated_sim is not None
|
||||
assert updated_sim.execution_result is not None
|
||||
assert "whoami" in updated_sim.execution_result
|
||||
|
||||
def test_completed_task_not_re_polled(
|
||||
self, app: Flask, monkeypatch, client: FlaskClient, admin_token: str
|
||||
) -> None:
|
||||
"""Once task.completed=True in DB, subsequent GETs skip polling (no re-poll)."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from backend.app.services.c2 import fake as fake_mod
|
||||
|
||||
call_count = {"n": 0}
|
||||
|
||||
def _completed(self, task_display_id: int) -> C2TaskStatus:
|
||||
call_count["n"] += 1
|
||||
return C2TaskStatus(
|
||||
display_id=task_display_id,
|
||||
status="completed",
|
||||
completed=True,
|
||||
completed_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(fake_mod.FakeAdapter, "get_task", _completed)
|
||||
|
||||
eng = _make_engagement(client, admin_token)
|
||||
_put_config(client, admin_token, eng["id"])
|
||||
sim = _make_sim(client, admin_token, eng["id"])
|
||||
_execute(client, admin_token, sim["id"], ["whoami"])
|
||||
|
||||
_list_tasks(client, admin_token, sim["id"]) # 1st GET — marks task completed (1 call)
|
||||
first_count = call_count["n"]
|
||||
|
||||
_list_tasks(client, admin_token, sim["id"]) # 2nd GET — task already completed, skip poll
|
||||
|
||||
# get_task should NOT have been called again on the 2nd GET.
|
||||
assert call_count["n"] == first_count, "completed task should not be re-polled"
|
||||
|
||||
resp = _list_tasks(client, admin_token, sim["id"])
|
||||
assert resp.status_code == 200
|
||||
task = resp.get_json()["tasks"][0]
|
||||
assert task["completed"] is True
|
||||
|
||||
def test_redteam_can_list_tasks(
|
||||
self, client: FlaskClient, admin_token: str, redteam_token: str
|
||||
) -> None:
|
||||
eng = _make_engagement(client, admin_token)
|
||||
_put_config(client, admin_token, eng["id"])
|
||||
sim = _make_sim(client, admin_token, eng["id"])
|
||||
_execute(client, admin_token, sim["id"], ["whoami"])
|
||||
|
||||
resp = _list_tasks(client, redteam_token, sim["id"])
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestListTasksErrors:
|
||||
def test_404_simulation_not_found(
|
||||
self, client: FlaskClient, admin_token: str
|
||||
) -> None:
|
||||
resp = _list_tasks(client, admin_token, 9999)
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_403_soc_forbidden(
|
||||
self, client: FlaskClient, admin_token: str, soc_token: str
|
||||
) -> None:
|
||||
eng = _make_engagement(client, admin_token)
|
||||
_put_config(client, admin_token, eng["id"])
|
||||
sim = _make_sim(client, admin_token, eng["id"])
|
||||
|
||||
resp = _list_tasks(client, soc_token, sim["id"])
|
||||
assert resp.status_code == 403
|
||||
|
||||
def test_503_no_encryption_key(
|
||||
self, monkeypatch, client: FlaskClient, admin_token: str
|
||||
) -> None:
|
||||
monkeypatch.delenv("MIMIC_ENCRYPTION_KEY", raising=False)
|
||||
eng = _make_engagement(client, admin_token)
|
||||
sim = _make_sim(client, admin_token, eng["id"])
|
||||
|
||||
resp = _list_tasks(client, admin_token, sim["id"])
|
||||
assert resp.status_code == 503
|
||||
|
||||
def test_404_no_c2_config(
|
||||
self, client: FlaskClient, admin_token: str
|
||||
) -> None:
|
||||
eng = _make_engagement(client, admin_token)
|
||||
sim = _make_sim(client, admin_token, eng["id"])
|
||||
|
||||
resp = _list_tasks(client, admin_token, sim["id"])
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_adapter_error_during_poll_is_tolerated(
|
||||
self, monkeypatch, client: FlaskClient, admin_token: str
|
||||
) -> None:
|
||||
"""If get_task raises C2Error during poll, the task is skipped (best-effort)."""
|
||||
from backend.app.services.c2 import fake as fake_mod
|
||||
|
||||
def _boom(self, task_display_id: int):
|
||||
raise C2Error("upstream unavailable")
|
||||
|
||||
monkeypatch.setattr(fake_mod.FakeAdapter, "get_task", _boom)
|
||||
|
||||
eng = _make_engagement(client, admin_token)
|
||||
_put_config(client, admin_token, eng["id"])
|
||||
sim = _make_sim(client, admin_token, eng["id"])
|
||||
_execute(client, admin_token, sim["id"], ["whoami"])
|
||||
|
||||
# Should still return 200 with the task (un-refreshed status).
|
||||
resp = _list_tasks(client, admin_token, sim["id"])
|
||||
assert resp.status_code == 200
|
||||
tasks = resp.get_json()["tasks"]
|
||||
assert len(tasks) == 1
|
||||
# Status is stale (not updated due to error) — still "submitted".
|
||||
assert tasks[0]["status"] == "submitted"
|
||||
124
backend/tests/test_migration_0007_c2.py
Normal file
124
backend/tests/test_migration_0007_c2.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""Migration round-trip test for 0007_c2_task_mapping_applied.
|
||||
|
||||
Verifies that upgrade() adds the mapping_applied column and downgrade() removes it.
|
||||
Uses the resolved-path pattern per lessons.md Sprint 4.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
from pathlib import Path
|
||||
|
||||
from alembic.operations import Operations
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
from sqlalchemy import create_engine, inspect, text
|
||||
|
||||
|
||||
def _load_migration(name: str):
|
||||
versions_dir = Path(__file__).resolve().parent.parent / "migrations" / "versions"
|
||||
path = versions_dir / name
|
||||
spec = importlib.util.spec_from_file_location(name.removesuffix(".py"), path)
|
||||
assert spec and spec.loader
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod) # type: ignore[union-attr]
|
||||
return mod
|
||||
|
||||
|
||||
def _fresh_engine_with_c2_task():
|
||||
"""In-memory SQLite with c2_task already created (as left by 0006 upgrade)."""
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
with engine.begin() as conn:
|
||||
conn.execute(text("""
|
||||
CREATE TABLE c2_task (
|
||||
id INTEGER PRIMARY KEY,
|
||||
simulation_id INTEGER NOT NULL,
|
||||
mythic_task_display_id INTEGER NOT NULL,
|
||||
callback_display_id INTEGER NOT NULL,
|
||||
command TEXT NOT NULL,
|
||||
params TEXT,
|
||||
status TEXT NOT NULL,
|
||||
completed BOOLEAN NOT NULL DEFAULT 0,
|
||||
output TEXT,
|
||||
source TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL,
|
||||
completed_at DATETIME
|
||||
)
|
||||
"""))
|
||||
return engine
|
||||
|
||||
|
||||
def _run_upgrade(engine, migration_mod):
|
||||
with engine.begin() as conn:
|
||||
ctx = MigrationContext.configure(conn)
|
||||
ops = Operations(ctx)
|
||||
ops._install_proxy() # type: ignore[attr-defined]
|
||||
try:
|
||||
migration_mod.upgrade()
|
||||
finally:
|
||||
ops._remove_proxy() # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _run_downgrade(engine, migration_mod):
|
||||
with engine.begin() as conn:
|
||||
ctx = MigrationContext.configure(conn)
|
||||
ops = Operations(ctx)
|
||||
ops._install_proxy() # type: ignore[attr-defined]
|
||||
try:
|
||||
migration_mod.downgrade()
|
||||
finally:
|
||||
ops._remove_proxy() # type: ignore[attr-defined]
|
||||
|
||||
|
||||
class TestMigration0007Upgrade:
|
||||
def test_mapping_applied_column_added(self):
|
||||
engine = _fresh_engine_with_c2_task()
|
||||
mod = _load_migration("0007_c2_task_mapping_applied.py")
|
||||
_run_upgrade(engine, mod)
|
||||
|
||||
insp = inspect(engine)
|
||||
cols = {c["name"] for c in insp.get_columns("c2_task")}
|
||||
assert "mapping_applied" in cols
|
||||
|
||||
def test_mapping_applied_defaults_to_false(self):
|
||||
engine = _fresh_engine_with_c2_task()
|
||||
mod = _load_migration("0007_c2_task_mapping_applied.py")
|
||||
|
||||
# Insert a row before upgrading (no mapping_applied column yet).
|
||||
with engine.begin() as conn:
|
||||
conn.execute(text(
|
||||
"INSERT INTO c2_task "
|
||||
"(simulation_id, mythic_task_display_id, callback_display_id, "
|
||||
"command, status, completed, source, created_at) "
|
||||
"VALUES (1, 1000, 1, 'whoami', 'submitted', 0, 'mimic', '2026-01-01')"
|
||||
))
|
||||
|
||||
_run_upgrade(engine, mod)
|
||||
|
||||
with engine.begin() as conn:
|
||||
row = conn.execute(
|
||||
text("SELECT mapping_applied FROM c2_task WHERE id = 1")
|
||||
).fetchone()
|
||||
assert row is not None
|
||||
# SQLite stores booleans as 0/1.
|
||||
assert row[0] == 0 or row[0] is False
|
||||
|
||||
|
||||
class TestMigration0007Downgrade:
|
||||
def test_downgrade_removes_mapping_applied(self):
|
||||
engine = _fresh_engine_with_c2_task()
|
||||
mod = _load_migration("0007_c2_task_mapping_applied.py")
|
||||
_run_upgrade(engine, mod)
|
||||
_run_downgrade(engine, mod)
|
||||
|
||||
insp = inspect(engine)
|
||||
cols = {c["name"] for c in insp.get_columns("c2_task")}
|
||||
assert "mapping_applied" not in cols
|
||||
|
||||
def test_downgrade_does_not_drop_other_columns(self):
|
||||
engine = _fresh_engine_with_c2_task()
|
||||
mod = _load_migration("0007_c2_task_mapping_applied.py")
|
||||
_run_upgrade(engine, mod)
|
||||
_run_downgrade(engine, mod)
|
||||
|
||||
insp = inspect(engine)
|
||||
cols = {c["name"] for c in insp.get_columns("c2_task")}
|
||||
assert {"id", "simulation_id", "command", "status", "completed"} <= cols
|
||||
Reference in New Issue
Block a user