fix(subscriptions): use x-api-key connectionParams format for WebSocket auth

The Unraid graphql-ws server expects the API key directly in connectionParams
as `x-api-key`, not nested under `headers`. The old format caused the server
to fall through to cookie auth and crash on `undefined.csrf_token`.

Fixed in snapshot.py (×2), manager.py, diagnostics.py, and updated the
integration test assertion to match the correct payload shape.
This commit is contained in:
Jacob Magar
2026-03-15 22:56:58 -04:00
parent 06368ce156
commit c37d4b1c5a
9 changed files with 1076 additions and 70 deletions

View File

@@ -43,9 +43,7 @@ class FakeWebSocket:
subprotocol: str = "graphql-transport-ws",
) -> None:
self.subprotocol = subprotocol
self._messages = [
json.dumps(m) if isinstance(m, dict) else m for m in messages
]
self._messages = [json.dumps(m) if isinstance(m, dict) else m for m in messages]
self._index = 0
self.send = AsyncMock()
@@ -54,9 +52,7 @@ class FakeWebSocket:
# Simulate normal connection close when messages exhausted
from websockets.frames import Close
raise websockets.exceptions.ConnectionClosed(
Close(1000, "normal closure"), None
)
raise websockets.exceptions.ConnectionClosed(Close(1000, "normal closure"), None)
msg = self._messages[self._index]
self._index += 1
return msg
@@ -96,7 +92,6 @@ _SLEEP = "unraid_mcp.subscriptions.manager.asyncio.sleep"
class TestSubscriptionManagerInit:
def test_default_state(self) -> None:
mgr = SubscriptionManager()
assert mgr.active_subscriptions == {}
@@ -137,7 +132,6 @@ class TestSubscriptionManagerInit:
class TestConnectionLifecycle:
async def test_start_subscription_creates_task(self) -> None:
mgr = SubscriptionManager()
ws = FakeWebSocket([{"type": "connection_ack"}])
@@ -242,16 +236,17 @@ def _loop_patches(
class TestProtocolHandling:
async def test_connection_init_sends_auth(self) -> None:
mgr = SubscriptionManager()
mgr.max_reconnect_attempts = 2
ws = FakeWebSocket([
{"type": "connection_ack"},
{"type": "next", "id": "test_sub", "payload": {"data": {"v": 1}}},
{"type": "complete", "id": "test_sub"},
])
ws = FakeWebSocket(
[
{"type": "connection_ack"},
{"type": "next", "id": "test_sub", "payload": {"data": {"v": 1}}},
{"type": "complete", "id": "test_sub"},
]
)
p = _loop_patches(ws, api_key="my-secret-key")
with p[0], p[1], p[2], p[3], p[4]:
await mgr._subscription_loop("test_sub", SAMPLE_QUERY, {})
@@ -259,7 +254,7 @@ class TestProtocolHandling:
first_send = ws.send.call_args_list[0]
init_msg = json.loads(first_send[0][0])
assert init_msg["type"] == "connection_init"
assert init_msg["payload"]["headers"]["X-API-Key"] == "my-secret-key"
assert init_msg["payload"]["x-api-key"] == "my-secret-key"
async def test_subscribe_uses_subscribe_type_for_transport_ws(self) -> None:
mgr = SubscriptionManager()
@@ -298,9 +293,11 @@ class TestProtocolHandling:
mgr = SubscriptionManager()
mgr.max_reconnect_attempts = 2
ws = FakeWebSocket([
{"type": "connection_error", "payload": {"message": "Invalid API key"}},
])
ws = FakeWebSocket(
[
{"type": "connection_error", "payload": {"message": "Invalid API key"}},
]
)
p = _loop_patches(ws, api_key="bad-key")
with p[0], p[1], p[2], p[3], p[4]:
await mgr._subscription_loop("test_sub", SAMPLE_QUERY, {})
@@ -312,10 +309,12 @@ class TestProtocolHandling:
mgr = SubscriptionManager()
mgr.max_reconnect_attempts = 2
ws = FakeWebSocket([
{"type": "connection_ack"},
{"type": "complete", "id": "test_sub"},
])
ws = FakeWebSocket(
[
{"type": "connection_ack"},
{"type": "complete", "id": "test_sub"},
]
)
p = _loop_patches(ws, api_key="")
with p[0], p[1], p[2], p[3], p[4]:
await mgr._subscription_loop("test_sub", SAMPLE_QUERY, {})
@@ -332,7 +331,6 @@ class TestProtocolHandling:
class TestDataReception:
async def test_next_message_stores_resource_data(self) -> None:
mgr = SubscriptionManager()
mgr.max_reconnect_attempts = 2
@@ -396,18 +394,19 @@ class TestDataReception:
mgr = SubscriptionManager()
mgr.max_reconnect_attempts = 2
ws = FakeWebSocket([
{"type": "connection_ack"},
{"type": "ping"},
{"type": "complete", "id": "test_sub"},
])
ws = FakeWebSocket(
[
{"type": "connection_ack"},
{"type": "ping"},
{"type": "complete", "id": "test_sub"},
]
)
p = _loop_patches(ws)
with p[0], p[1], p[2], p[3], p[4]:
await mgr._subscription_loop("test_sub", SAMPLE_QUERY, {})
pong_sent = any(
json.loads(call[0][0]).get("type") == "pong"
for call in ws.send.call_args_list
json.loads(call[0][0]).get("type") == "pong" for call in ws.send.call_args_list
)
assert pong_sent, "Expected pong response to be sent"
@@ -415,11 +414,13 @@ class TestDataReception:
mgr = SubscriptionManager()
mgr.max_reconnect_attempts = 2
ws = FakeWebSocket([
{"type": "connection_ack"},
{"type": "error", "id": "test_sub", "payload": {"message": "bad query"}},
{"type": "complete", "id": "test_sub"},
])
ws = FakeWebSocket(
[
{"type": "connection_ack"},
{"type": "error", "id": "test_sub", "payload": {"message": "bad query"}},
{"type": "complete", "id": "test_sub"},
]
)
p = _loop_patches(ws)
with p[0], p[1], p[2], p[3], p[4]:
await mgr._subscription_loop("test_sub", SAMPLE_QUERY, {})
@@ -432,10 +433,12 @@ class TestDataReception:
mgr = SubscriptionManager()
mgr.max_reconnect_attempts = 2
ws = FakeWebSocket([
{"type": "connection_ack"},
{"type": "complete", "id": "test_sub"},
])
ws = FakeWebSocket(
[
{"type": "connection_ack"},
{"type": "complete", "id": "test_sub"},
]
)
p = _loop_patches(ws)
with p[0], p[1], p[2], p[3], p[4]:
await mgr._subscription_loop("test_sub", SAMPLE_QUERY, {})
@@ -465,12 +468,14 @@ class TestDataReception:
mgr = SubscriptionManager()
mgr.max_reconnect_attempts = 2
ws = FakeWebSocket([
{"type": "connection_ack"},
{"type": "ka"},
{"type": "pong"},
{"type": "complete", "id": "test_sub"},
])
ws = FakeWebSocket(
[
{"type": "connection_ack"},
{"type": "ka"},
{"type": "pong"},
{"type": "complete", "id": "test_sub"},
]
)
p = _loop_patches(ws)
with p[0], p[1], p[2], p[3], p[4]:
await mgr._subscription_loop("test_sub", SAMPLE_QUERY, {})
@@ -479,11 +484,13 @@ class TestDataReception:
mgr = SubscriptionManager()
mgr.max_reconnect_attempts = 2
ws = FakeWebSocket([
{"type": "connection_ack"},
"not valid json {{{",
{"type": "complete", "id": "test_sub"},
])
ws = FakeWebSocket(
[
{"type": "connection_ack"},
"not valid json {{{",
{"type": "complete", "id": "test_sub"},
]
)
p = _loop_patches(ws)
with p[0], p[1], p[2], p[3], p[4]:
await mgr._subscription_loop("test_sub", SAMPLE_QUERY, {})
@@ -495,7 +502,6 @@ class TestDataReception:
class TestReconnection:
async def test_max_retries_exceeded_stops_loop(self) -> None:
mgr = SubscriptionManager()
mgr.max_reconnect_attempts = 2
@@ -558,10 +564,12 @@ class TestReconnection:
mgr.max_reconnect_attempts = 10
mgr.reconnect_attempts["test_sub"] = 5
ws = FakeWebSocket([
{"type": "connection_ack"},
{"type": "complete", "id": "test_sub"},
])
ws = FakeWebSocket(
[
{"type": "connection_ack"},
{"type": "complete", "id": "test_sub"},
]
)
p = _loop_patches(ws)
with p[0], p[1], p[2], p[3], p[4]:
await mgr._subscription_loop("test_sub", SAMPLE_QUERY, {})
@@ -622,9 +630,7 @@ class TestReconnection:
with (
patch(
_WS_CONNECT,
side_effect=websockets.exceptions.ConnectionClosed(
Close(1006, "abnormal"), None
),
side_effect=websockets.exceptions.ConnectionClosed(Close(1006, "abnormal"), None),
),
patch(_API_URL, "https://test.local"),
patch(_API_KEY, "key"),
@@ -643,7 +649,6 @@ class TestReconnection:
class TestWebSocketURLConstruction:
async def test_https_converted_to_wss(self) -> None:
mgr = SubscriptionManager()
mgr.max_reconnect_attempts = 1
@@ -720,7 +725,6 @@ class TestWebSocketURLConstruction:
class TestResourceData:
async def test_get_resource_data_returns_none_when_empty(self) -> None:
mgr = SubscriptionManager()
assert await mgr.get_resource_data("nonexistent") is None
@@ -755,7 +759,6 @@ class TestResourceData:
class TestSubscriptionStatus:
async def test_status_includes_all_configured_subscriptions(self) -> None:
mgr = SubscriptionManager()
status = await mgr.get_subscription_status()
@@ -805,7 +808,6 @@ class TestSubscriptionStatus:
class TestAutoStart:
async def test_auto_start_disabled_skips_all(self) -> None:
mgr = SubscriptionManager()
mgr.auto_start_enabled = False
@@ -853,7 +855,6 @@ class TestAutoStart:
class TestSSLContext:
def test_non_wss_returns_none(self) -> None:
from unraid_mcp.subscriptions.utils import build_ws_ssl_context