mirror of
https://github.com/jmagar/unraid-mcp.git
synced 2026-03-01 16:04:24 -08:00
Add comprehensive test coverage beyond unit tests:
- Schema validation (93 tests): validate all GraphQL queries/mutations against the extracted Unraid API schema
- HTTP layer (88 tests): test request construction, timeouts, and error handling at the httpx level
- Subscriptions (55 tests): WebSocket lifecycle, reconnection, and protocol validation
- Safety audit (39 tests): enforce destructive-action confirmation requirements

Total test count increased from 210 to 485 (≈131% increase), all passing in 5.91s.

New dependencies:
- graphql-core>=3.2.0 for schema validation
- respx>=0.22.0 for HTTP-layer mocking

Files created:
- docs/unraid-schema.graphql (150-type GraphQL schema)
- tests/schema/test_query_validation.py
- tests/http_layer/test_request_construction.py
- tests/integration/test_subscriptions.py
- tests/safety/test_destructive_guards.py

Co-authored-by: Claude <claude@anthropic.com>
891 lines
32 KiB
Python
891 lines
32 KiB
Python
"""Integration tests for WebSocket subscription lifecycle and reconnection logic.

These tests validate the SubscriptionManager's connection lifecycle,
reconnection with exponential backoff, protocol handling (both the
``graphql-transport-ws`` and legacy ``graphql-ws`` subprotocols), and
resource-data management without requiring a live Unraid server:
``websockets.connect`` is patched to hand the manager scripted fake sockets.
"""
|
|
|
|
import asyncio
|
|
import json
|
|
from datetime import datetime
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
import websockets.exceptions
|
|
|
|
from unraid_mcp.subscriptions.manager import SubscriptionManager
|
|
|
|
# Apply the ``integration`` marker to every test in this module so the whole
# file can be selected or deselected with ``pytest -m integration``.
pytestmark = pytest.mark.integration
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class FakeWebSocket:
|
|
"""Minimal fake WebSocket that supports both recv() and async-for iteration.
|
|
|
|
The manager calls ``recv()`` once for the connection_ack, then enters
|
|
``async for message in websocket:`` for the data stream. This class
|
|
tracks a shared message queue so both paths draw from the same list.
|
|
|
|
When messages are exhausted, iteration ends cleanly via StopAsyncIteration
|
|
(which terminates ``async for``), and ``recv()`` raises ConnectionClosed
|
|
so the manager treats it as a normal disconnection.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
messages: list[dict[str, Any] | str],
|
|
subprotocol: str = "graphql-transport-ws",
|
|
) -> None:
|
|
self.subprotocol = subprotocol
|
|
self._messages = [
|
|
json.dumps(m) if isinstance(m, dict) else m for m in messages
|
|
]
|
|
self._index = 0
|
|
self.send = AsyncMock()
|
|
|
|
async def recv(self) -> str:
|
|
if self._index >= len(self._messages):
|
|
# Simulate normal connection close when messages exhausted
|
|
from websockets.frames import Close
|
|
|
|
raise websockets.exceptions.ConnectionClosed(
|
|
Close(1000, "normal closure"), None
|
|
)
|
|
msg = self._messages[self._index]
|
|
self._index += 1
|
|
return msg
|
|
|
|
def __aiter__(self) -> "FakeWebSocket":
|
|
return self
|
|
|
|
async def __anext__(self) -> str:
|
|
if self._index >= len(self._messages):
|
|
raise StopAsyncIteration
|
|
msg = self._messages[self._index]
|
|
self._index += 1
|
|
return msg
|
|
|
|
|
|
def _ws_context(ws: FakeWebSocket) -> MagicMock:
    """Build an async-context-manager mock whose ``__aenter__`` yields *ws*.

    Returned from a patched ``websockets.connect`` so the manager's
    ``async with websockets.connect(...) as websocket:`` receives the fake.
    ``__aexit__`` returns ``False`` so exceptions raised inside the block
    still propagate.
    """
    fake_connection = MagicMock(
        __aenter__=AsyncMock(return_value=ws),
        __aexit__=AsyncMock(return_value=False),
    )
    return fake_connection
|
|
|
|
|
|
SAMPLE_QUERY = "subscription { test { value } }"

# Shared patch targets, all resolved through the manager module's namespace.
# NOTE: ``_WS_CONNECT`` and ``_SLEEP`` walk through module objects
# (``websockets`` / ``asyncio``), so while patched they replace
# ``websockets.connect`` / ``asyncio.sleep`` process-wide, not only in the manager.
_MANAGER_MODULE = "unraid_mcp.subscriptions.manager"
_WS_CONNECT = f"{_MANAGER_MODULE}.websockets.connect"
_API_URL = f"{_MANAGER_MODULE}.UNRAID_API_URL"
_API_KEY = f"{_MANAGER_MODULE}.UNRAID_API_KEY"
_SSL_CTX = f"{_MANAGER_MODULE}.build_ws_ssl_context"
_SLEEP = f"{_MANAGER_MODULE}.asyncio.sleep"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# SubscriptionManager Initialisation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSubscriptionManagerInit:
    """Constructor defaults and environment-variable overrides."""

    @staticmethod
    def _manager_without_env(*names: str) -> SubscriptionManager:
        """Construct a manager with the given environment variables unset.

        ``patch.dict`` restores ``os.environ`` on exit, so tests that assert
        built-in defaults don't depend on whatever the developer's shell or CI
        happens to export.
        """
        import os

        with patch.dict(os.environ):
            for name in names:
                os.environ.pop(name, None)
            return SubscriptionManager()

    def test_default_state(self) -> None:
        mgr = SubscriptionManager()
        assert mgr.active_subscriptions == {}
        assert mgr.resource_data == {}
        assert mgr.websocket is None

    def test_default_auto_start_enabled(self) -> None:
        mgr = self._manager_without_env("UNRAID_AUTO_START_SUBSCRIPTIONS")
        assert mgr.auto_start_enabled is True

    @patch.dict("os.environ", {"UNRAID_AUTO_START_SUBSCRIPTIONS": "false"})
    def test_auto_start_disabled_via_env(self) -> None:
        mgr = SubscriptionManager()
        assert mgr.auto_start_enabled is False

    def test_default_max_reconnect_attempts(self) -> None:
        mgr = self._manager_without_env("UNRAID_MAX_RECONNECT_ATTEMPTS")
        assert mgr.max_reconnect_attempts == 10

    @patch.dict("os.environ", {"UNRAID_MAX_RECONNECT_ATTEMPTS": "5"})
    def test_custom_max_reconnect_attempts(self) -> None:
        mgr = SubscriptionManager()
        assert mgr.max_reconnect_attempts == 5

    def test_subscription_configs_contain_log_file(self) -> None:
        mgr = SubscriptionManager()
        assert "logFileSubscription" in mgr.subscription_configs

    def test_log_file_subscription_not_auto_start(self) -> None:
        mgr = SubscriptionManager()
        cfg = mgr.subscription_configs["logFileSubscription"]
        assert cfg.get("auto_start") is False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Connection Lifecycle
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestConnectionLifecycle:
    """Start/stop behaviour of the background task behind each subscription.

    Each test that starts a subscription stops it in a ``finally`` clause, so
    a failing assertion cannot leave the task running after the patches are
    undone (where a reconnect attempt would reach the real ``websockets.connect``).
    """

    async def test_start_subscription_creates_task(self) -> None:
        mgr = SubscriptionManager()
        ws = FakeWebSocket([{"type": "connection_ack"}])
        ctx = _ws_context(ws)

        with (
            patch(_WS_CONNECT, return_value=ctx),
            patch(_API_URL, "https://test.local"),
            patch(_API_KEY, "test-key"),
            patch(_SSL_CTX, return_value=None),
        ):
            await mgr.start_subscription("test_sub", SAMPLE_QUERY)
            try:
                assert "test_sub" in mgr.active_subscriptions
                assert isinstance(mgr.active_subscriptions["test_sub"], asyncio.Task)
            finally:
                await mgr.stop_subscription("test_sub")

    async def test_duplicate_start_is_noop(self) -> None:
        mgr = SubscriptionManager()
        ws = FakeWebSocket([{"type": "connection_ack"}])
        ctx = _ws_context(ws)

        with (
            patch(_WS_CONNECT, return_value=ctx),
            patch(_API_URL, "https://test.local"),
            patch(_API_KEY, "test-key"),
            patch(_SSL_CTX, return_value=None),
        ):
            await mgr.start_subscription("test_sub", SAMPLE_QUERY)
            try:
                first_task = mgr.active_subscriptions["test_sub"]
                await mgr.start_subscription("test_sub", SAMPLE_QUERY)
                assert mgr.active_subscriptions["test_sub"] is first_task
            finally:
                await mgr.stop_subscription("test_sub")

    async def test_stop_subscription_cancels_task(self) -> None:
        mgr = SubscriptionManager()
        ws = FakeWebSocket([{"type": "connection_ack"}])
        ctx = _ws_context(ws)

        with (
            patch(_WS_CONNECT, return_value=ctx),
            patch(_API_URL, "https://test.local"),
            patch(_API_KEY, "test-key"),
            patch(_SSL_CTX, return_value=None),
        ):
            await mgr.start_subscription("test_sub", SAMPLE_QUERY)
            try:
                assert "test_sub" in mgr.active_subscriptions
            finally:
                await mgr.stop_subscription("test_sub")
        assert "test_sub" not in mgr.active_subscriptions
        assert mgr.connection_states.get("test_sub") == "stopped"

    async def test_stop_nonexistent_subscription_is_safe(self) -> None:
        mgr = SubscriptionManager()
        await mgr.stop_subscription("nonexistent")
        # Stopping an unknown name must neither raise nor register anything.
        assert "nonexistent" not in mgr.active_subscriptions

    async def test_connection_state_transitions(self) -> None:
        mgr = SubscriptionManager()
        ws = FakeWebSocket([{"type": "connection_ack"}])
        ctx = _ws_context(ws)

        with (
            patch(_WS_CONNECT, return_value=ctx),
            patch(_API_URL, "https://test.local"),
            patch(_API_KEY, "test-key"),
            patch(_SSL_CTX, return_value=None),
        ):
            await mgr.start_subscription("test_sub", SAMPLE_QUERY)
            try:
                assert mgr.connection_states["test_sub"] == "active"
            finally:
                await mgr.stop_subscription("test_sub")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Protocol Handling (via _subscription_loop)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _loop_patches(
    ws: FakeWebSocket,
    api_key: str = "test-key",
) -> tuple:
    """Build the patchers for tests that drive ``_subscription_loop`` directly.

    The first ``websockets.connect`` call yields *ws*. Every later call raises
    ``ConnectionRefusedError``, so the reconnect loop eventually gives up.
    ``asyncio.sleep`` is replaced by an ``AsyncMock`` so backoff costs no
    wall-clock time.

    Args:
        ws: Fake socket served on the first connection attempt only.
        api_key: Value patched in for ``UNRAID_API_KEY``.

    Returns:
        Five un-entered patchers in this order: connect, API URL, API key,
        SSL-context builder, sleep.
    """
    # One-shot supply of connection contexts; emptied by the first connect.
    pending = [_ws_context(ws)]

    def _connect_side_effect(*_a: Any, **_kw: Any) -> MagicMock:
        if pending:
            return pending.pop()
        raise ConnectionRefusedError("no more test connections")

    connect_patch = patch(_WS_CONNECT, side_effect=_connect_side_effect)
    url_patch = patch(_API_URL, "https://test.local")
    key_patch = patch(_API_KEY, api_key)
    ssl_patch = patch(_SSL_CTX, return_value=None)
    sleep_patch = patch(_SLEEP, new_callable=AsyncMock)
    return connect_patch, url_patch, key_patch, ssl_patch, sleep_patch
|
|
|
|
|
|
class TestProtocolHandling:
    """Handshake and subscribe frames sent for each WebSocket subprotocol."""

    @staticmethod
    async def _drive(ws: FakeWebSocket, api_key: str = "test-key") -> SubscriptionManager:
        """Run ``_subscription_loop`` for ``test_sub`` against *ws*; return the manager.

        Every patcher returned by ``_loop_patches`` is entered through an
        ``ExitStack``, so these tests keep working if that helper gains
        another patch. ``max_reconnect_attempts`` is lowered to 2 so the loop
        gives up quickly once the single scripted connection is exhausted.
        """
        from contextlib import ExitStack

        mgr = SubscriptionManager()
        mgr.max_reconnect_attempts = 2
        with ExitStack() as stack:
            for patcher in _loop_patches(ws, api_key=api_key):
                stack.enter_context(patcher)
            await mgr._subscription_loop("test_sub", SAMPLE_QUERY, {})
        return mgr

    @staticmethod
    def _sent(ws: FakeWebSocket, index: int) -> dict[str, Any]:
        """Decode the *index*-th frame the manager sent on *ws*."""
        return json.loads(ws.send.call_args_list[index][0][0])

    async def test_connection_init_sends_auth(self) -> None:
        ws = FakeWebSocket([
            {"type": "connection_ack"},
            {"type": "next", "id": "test_sub", "payload": {"data": {"v": 1}}},
            {"type": "complete", "id": "test_sub"},
        ])
        await self._drive(ws, api_key="my-secret-key")

        init_msg = self._sent(ws, 0)
        assert init_msg["type"] == "connection_init"
        assert init_msg["payload"]["headers"]["X-API-Key"] == "my-secret-key"

    async def test_subscribe_uses_subscribe_type_for_transport_ws(self) -> None:
        ws = FakeWebSocket(
            [{"type": "connection_ack"}, {"type": "complete", "id": "test_sub"}],
            subprotocol="graphql-transport-ws",
        )
        await self._drive(ws)

        sub_msg = self._sent(ws, 1)
        assert sub_msg["type"] == "subscribe"
        assert sub_msg["id"] == "test_sub"

    async def test_subscribe_uses_start_type_for_graphql_ws(self) -> None:
        ws = FakeWebSocket(
            [{"type": "connection_ack"}, {"type": "complete", "id": "test_sub"}],
            subprotocol="graphql-ws",
        )
        await self._drive(ws)

        sub_msg = self._sent(ws, 1)
        assert sub_msg["type"] == "start"
        # The legacy protocol must use the same operation id the manager matches on.
        assert sub_msg["id"] == "test_sub"

    async def test_connection_error_sets_auth_failed(self) -> None:
        ws = FakeWebSocket([
            {"type": "connection_error", "payload": {"message": "Invalid API key"}},
        ])
        mgr = await self._drive(ws, api_key="bad-key")

        assert mgr.connection_states["test_sub"] == "auth_failed"
        assert "Authentication error" in mgr.last_error["test_sub"]

    async def test_no_api_key_omits_payload(self) -> None:
        ws = FakeWebSocket([
            {"type": "connection_ack"},
            {"type": "complete", "id": "test_sub"},
        ])
        await self._drive(ws, api_key="")

        init_msg = self._sent(ws, 0)
        assert init_msg["type"] == "connection_init"
        assert "payload" not in init_msg
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Data Reception
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestDataReception:
    """How incoming frames update ``resource_data`` and trigger replies."""

    @staticmethod
    async def _drive(ws: FakeWebSocket) -> SubscriptionManager:
        """Run ``_subscription_loop`` for ``test_sub`` against *ws*; return the manager.

        Every patcher from ``_loop_patches`` is entered via ``ExitStack``, so
        none is silently skipped if that helper grows. ``max_reconnect_attempts``
        is lowered to 2 so the loop gives up quickly after the scripted socket
        is exhausted.
        """
        from contextlib import ExitStack

        mgr = SubscriptionManager()
        mgr.max_reconnect_attempts = 2
        with ExitStack() as stack:
            for patcher in _loop_patches(ws):
                stack.enter_context(patcher)
            await mgr._subscription_loop("test_sub", SAMPLE_QUERY, {})
        return mgr

    async def test_next_message_stores_resource_data(self) -> None:
        ws = FakeWebSocket(
            [
                {"type": "connection_ack"},
                {"type": "next", "id": "test_sub", "payload": {"data": {"test": {"value": 42}}}},
                {"type": "complete", "id": "test_sub"},
            ],
            subprotocol="graphql-transport-ws",
        )
        mgr = await self._drive(ws)

        assert "test_sub" in mgr.resource_data
        assert mgr.resource_data["test_sub"].data == {"test": {"value": 42}}
        assert mgr.resource_data["test_sub"].subscription_type == "test_sub"

    async def test_data_message_for_legacy_protocol(self) -> None:
        ws = FakeWebSocket(
            [
                {"type": "connection_ack"},
                {"type": "data", "id": "test_sub", "payload": {"data": {"legacy": True}}},
                {"type": "complete", "id": "test_sub"},
            ],
            subprotocol="graphql-ws",
        )
        mgr = await self._drive(ws)

        assert "test_sub" in mgr.resource_data
        assert mgr.resource_data["test_sub"].data == {"legacy": True}

    async def test_graphql_errors_tracked_in_last_error(self) -> None:
        ws = FakeWebSocket(
            [
                {"type": "connection_ack"},
                {"type": "next", "id": "test_sub", "payload": {"errors": [{"message": "bad"}]}},
                {"type": "complete", "id": "test_sub"},
            ],
            subprotocol="graphql-transport-ws",
        )
        mgr = await self._drive(ws)

        # The last_error may be overwritten by a subsequent reconnection error,
        # so check the resource_data wasn't stored (errors in payload means no data)
        assert "test_sub" not in mgr.resource_data

    async def test_ping_receives_pong_response(self) -> None:
        ws = FakeWebSocket([
            {"type": "connection_ack"},
            {"type": "ping"},
            {"type": "complete", "id": "test_sub"},
        ])
        await self._drive(ws)

        pong_sent = any(
            json.loads(call[0][0]).get("type") == "pong"
            for call in ws.send.call_args_list
        )
        assert pong_sent, "Expected pong response to be sent"

    async def test_error_message_sets_error_state(self) -> None:
        ws = FakeWebSocket([
            {"type": "connection_ack"},
            {"type": "error", "id": "test_sub", "payload": {"message": "bad query"}},
            {"type": "complete", "id": "test_sub"},
        ])
        mgr = await self._drive(ws)

        # Verify the error was recorded at some point by checking resource_data
        # was not stored (error messages don't produce data)
        assert "test_sub" not in mgr.resource_data

    async def test_complete_message_breaks_inner_loop(self) -> None:
        ws = FakeWebSocket([
            {"type": "connection_ack"},
            {"type": "complete", "id": "test_sub"},
        ])
        mgr = await self._drive(ws)

        # complete message was processed (test finished, loop terminated)
        assert "test_sub" not in mgr.resource_data

    async def test_mismatched_id_ignored(self) -> None:
        ws = FakeWebSocket(
            [
                {"type": "connection_ack"},
                {"type": "next", "id": "other_sub", "payload": {"data": {"wrong": True}}},
                {"type": "complete", "id": "test_sub"},
            ],
            subprotocol="graphql-transport-ws",
        )
        mgr = await self._drive(ws)

        assert "test_sub" not in mgr.resource_data

    async def test_keepalive_messages_handled(self) -> None:
        ws = FakeWebSocket([
            {"type": "connection_ack"},
            {"type": "ka"},
            {"type": "pong"},
            {"type": "complete", "id": "test_sub"},
        ])
        mgr = await self._drive(ws)

        # The handshake completed (connection_init + subscribe) and the
        # keepalive frames did not masquerade as subscription data.
        assert ws.send.call_count >= 2
        assert "test_sub" not in mgr.resource_data

    async def test_invalid_json_message_handled(self) -> None:
        ws = FakeWebSocket([
            {"type": "connection_ack"},
            "not valid json {{{",
            {"type": "complete", "id": "test_sub"},
        ])
        mgr = await self._drive(ws)

        # The malformed frame must not crash the loop or be stored as data.
        assert ws.send.call_count >= 2
        assert "test_sub" not in mgr.resource_data
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Reconnection and Backoff
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestReconnection:
    """Retry limits, backoff delays and error classification on reconnect."""

    @staticmethod
    async def _run_failing(
        connect_side_effect: Any,
        max_attempts: int,
    ) -> tuple[SubscriptionManager, AsyncMock]:
        """Run ``_subscription_loop`` where every connect attempt hits *connect_side_effect*.

        Args:
            connect_side_effect: Exception (or side effect) for the patched
                ``websockets.connect``.
            max_attempts: Value assigned to ``max_reconnect_attempts``.

        Returns:
            The manager (for final state) and the ``asyncio.sleep`` mock (for
            the requested backoff delays).
        """
        mgr = SubscriptionManager()
        mgr.max_reconnect_attempts = max_attempts
        sleep_mock = AsyncMock()

        with (
            patch(_WS_CONNECT, side_effect=connect_side_effect),
            patch(_API_URL, "https://test.local"),
            patch(_API_KEY, "key"),
            patch(_SSL_CTX, return_value=None),
            patch(_SLEEP, sleep_mock),
        ):
            await mgr._subscription_loop("test_sub", SAMPLE_QUERY, {})
        return mgr, sleep_mock

    async def test_max_retries_exceeded_stops_loop(self) -> None:
        mgr, _ = await self._run_failing(ConnectionRefusedError("refused"), max_attempts=2)

        assert mgr.connection_states["test_sub"] == "max_retries_exceeded"
        assert mgr.reconnect_attempts["test_sub"] > mgr.max_reconnect_attempts

    async def test_backoff_delay_increases(self) -> None:
        _, sleep_mock = await self._run_failing(
            ConnectionRefusedError("refused"), max_attempts=3
        )

        delays = [call[0][0] for call in sleep_mock.call_args_list]
        assert len(delays) >= 2
        for i in range(1, len(delays)):
            assert delays[i] > delays[i - 1], (
                f"Delay should increase: {delays[i]} > {delays[i - 1]}"
            )

    async def test_backoff_capped_at_max(self) -> None:
        _, sleep_mock = await self._run_failing(
            ConnectionRefusedError("refused"), max_attempts=50
        )

        delays = [call[0][0] for call in sleep_mock.call_args_list]
        # Without this, the cap check below would pass vacuously if no
        # backoff sleep happened at all.
        assert delays, "Expected at least one backoff sleep"
        for d in delays:
            assert d <= 300, f"Delay {d} exceeds max of 300 seconds"

    async def test_successful_connection_resets_retry_count(self) -> None:
        from contextlib import ExitStack

        mgr = SubscriptionManager()
        mgr.max_reconnect_attempts = 10
        mgr.reconnect_attempts["test_sub"] = 5

        ws = FakeWebSocket([
            {"type": "connection_ack"},
            {"type": "complete", "id": "test_sub"},
        ])
        with ExitStack() as stack:
            for patcher in _loop_patches(ws):
                stack.enter_context(patcher)
            await mgr._subscription_loop("test_sub", SAMPLE_QUERY, {})

        # After successful connection, attempts reset to 0 internally.
        # The loop then reconnects, fails, and increments. But since we
        # started at 5, the key check is that we didn't immediately bail.
        # Verify some messages were processed (connection was established).
        assert ws.send.call_count >= 2  # connection_init + subscribe

    async def test_invalid_uri_does_not_retry(self) -> None:
        mgr, sleep_mock = await self._run_failing(
            websockets.exceptions.InvalidURI("bad://url", "Invalid URI"), max_attempts=5
        )

        assert mgr.connection_states["test_sub"] == "invalid_uri"
        sleep_mock.assert_not_called()

    async def test_timeout_error_triggers_reconnect(self) -> None:
        mgr, sleep_mock = await self._run_failing(
            TimeoutError("connection timeout"), max_attempts=2
        )

        assert mgr.last_error["test_sub"] == "Connection or authentication timeout"
        assert sleep_mock.call_count >= 1

    async def test_connection_closed_triggers_reconnect(self) -> None:
        from websockets.frames import Close

        mgr, _ = await self._run_failing(
            websockets.exceptions.ConnectionClosed(Close(1006, "abnormal"), None),
            max_attempts=2,
        )

        assert "WebSocket connection closed" in mgr.last_error.get("test_sub", "")
        assert mgr.connection_states["test_sub"] in ("disconnected", "max_retries_exceeded")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# WebSocket URL Construction
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestWebSocketURLConstruction:
    """Conversion of ``UNRAID_API_URL`` into the WebSocket endpoint URL."""

    @staticmethod
    async def _connect_url(api_url: str) -> str:
        """Return the URL the manager passes to ``websockets.connect`` for *api_url*.

        Every connection is refused and ``max_reconnect_attempts`` is 1, so
        only the URL-building path and a short retry run.

        Raises:
            AssertionError: If ``websockets.connect`` was never called. This
                gives a clear failure instead of a ``TypeError`` on
                ``call_args``.
        """
        mgr = SubscriptionManager()
        mgr.max_reconnect_attempts = 1
        connect_mock = MagicMock(side_effect=ConnectionRefusedError("test"))

        with (
            patch(_WS_CONNECT, connect_mock),
            patch(_API_URL, api_url),
            patch(_API_KEY, "key"),
            patch(_SSL_CTX, return_value=None),
            patch(_SLEEP, new_callable=AsyncMock),
        ):
            await mgr._subscription_loop("test_sub", SAMPLE_QUERY, {})

        assert connect_mock.called, "websockets.connect was never called"
        return connect_mock.call_args[0][0]

    async def test_https_converted_to_wss(self) -> None:
        url_arg = await self._connect_url("https://myserver.local:31337")
        assert url_arg.startswith("wss://")
        assert url_arg.endswith("/graphql")

    async def test_http_converted_to_ws(self) -> None:
        url_arg = await self._connect_url("http://192.168.1.100:8080")
        assert url_arg.startswith("ws://")
        assert url_arg.endswith("/graphql")

    async def test_no_api_url_raises_value_error(self) -> None:
        mgr = SubscriptionManager()
        mgr.max_reconnect_attempts = 1
        # Guard against a regression that skips URL validation and would
        # otherwise attempt a real network connection.
        connect_mock = MagicMock(side_effect=ConnectionRefusedError("test"))

        with (
            patch(_WS_CONNECT, connect_mock),
            patch(_API_URL, ""),
            patch(_API_KEY, "key"),
            patch(_SSL_CTX, return_value=None),
            patch(_SLEEP, new_callable=AsyncMock),
        ):
            await mgr._subscription_loop("test_sub", SAMPLE_QUERY, {})

        assert mgr.connection_states["test_sub"] in ("error", "max_retries_exceeded")
        connect_mock.assert_not_called()

    async def test_graphql_suffix_not_duplicated(self) -> None:
        url_arg = await self._connect_url("https://myserver.local/graphql")
        assert url_arg == "wss://myserver.local/graphql"
        assert "/graphql/graphql" not in url_arg
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Resource Data Access
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestResourceData:
    """Accessors for cached subscription payloads and active subscription names."""

    def test_get_resource_data_returns_none_when_empty(self) -> None:
        assert SubscriptionManager().get_resource_data("nonexistent") is None

    def test_get_resource_data_returns_stored_data(self) -> None:
        from unraid_mcp.core.types import SubscriptionData

        manager = SubscriptionManager()
        cached = SubscriptionData(
            data={"key": "value"},
            last_updated=datetime.now(),
            subscription_type="test",
        )
        manager.resource_data["test"] = cached

        assert manager.get_resource_data("test") == {"key": "value"}

    def test_list_active_subscriptions_empty(self) -> None:
        assert SubscriptionManager().list_active_subscriptions() == []

    def test_list_active_subscriptions_returns_names(self) -> None:
        manager = SubscriptionManager()
        for sub_name in ("sub_a", "sub_b"):
            manager.active_subscriptions[sub_name] = MagicMock()

        names = manager.list_active_subscriptions()
        assert sorted(names) == ["sub_a", "sub_b"]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Subscription Status Diagnostics
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSubscriptionStatus:
    """Shape and contents of the ``get_subscription_status()`` diagnostics."""

    def test_status_includes_all_configured_subscriptions(self) -> None:
        manager = SubscriptionManager()
        report = manager.get_subscription_status()
        missing = [name for name in manager.subscription_configs if name not in report]
        assert not missing

    def test_status_default_connection_state(self) -> None:
        report = SubscriptionManager().get_subscription_status()
        assert all(
            entry["runtime"]["connection_state"] == "not_started"
            for entry in report.values()
        )

    def test_status_shows_active_flag(self) -> None:
        manager = SubscriptionManager()
        manager.active_subscriptions["logFileSubscription"] = MagicMock()
        runtime = manager.get_subscription_status()["logFileSubscription"]["runtime"]
        assert runtime["active"] is True

    def test_status_shows_data_availability(self) -> None:
        from unraid_mcp.core.types import SubscriptionData

        manager = SubscriptionManager()
        snapshot = SubscriptionData(
            data={"log": "content"},
            last_updated=datetime.now(),
            subscription_type="logFileSubscription",
        )
        manager.resource_data["logFileSubscription"] = snapshot

        data_info = manager.get_subscription_status()["logFileSubscription"]["data"]
        assert data_info["available"] is True

    def test_status_shows_error_info(self) -> None:
        manager = SubscriptionManager()
        manager.last_error["logFileSubscription"] = "Test error message"
        runtime = manager.get_subscription_status()["logFileSubscription"]["runtime"]
        assert runtime["last_error"] == "Test error message"

    def test_status_reconnect_attempts_tracked(self) -> None:
        manager = SubscriptionManager()
        manager.reconnect_attempts["logFileSubscription"] = 3
        runtime = manager.get_subscription_status()["logFileSubscription"]["runtime"]
        assert runtime["reconnect_attempts"] == 3
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Auto-Start
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestAutoStart:
    """``auto_start_all_subscriptions`` gating and per-config selection."""

    async def test_auto_start_disabled_skips_all(self) -> None:
        mgr = SubscriptionManager()
        mgr.auto_start_enabled = False
        # Register a config flagged for auto-start. The default configs never
        # auto-start (see test_auto_start_only_starts_marked_subscriptions),
        # so without this the test could not fail even if the flag were ignored.
        mgr.subscription_configs["auto_sub"] = {
            "query": "subscription { auto }",
            "resource": "unraid://auto",
            "description": "Auto sub",
            "auto_start": True,
        }

        with patch.object(mgr, "start_subscription", new_callable=AsyncMock) as mock_start:
            await mgr.auto_start_all_subscriptions()
        mock_start.assert_not_called()
        assert mgr.active_subscriptions == {}

    async def test_auto_start_only_starts_marked_subscriptions(self) -> None:
        mgr = SubscriptionManager()
        with patch.object(mgr, "start_subscription", new_callable=AsyncMock) as mock_start:
            await mgr.auto_start_all_subscriptions()
        mock_start.assert_not_called()

    async def test_auto_start_handles_failure_gracefully(self) -> None:
        mgr = SubscriptionManager()
        mgr.subscription_configs["test_auto"] = {
            "query": "subscription { test }",
            "resource": "unraid://test",
            "description": "Test auto-start",
            "auto_start": True,
        }

        with patch.object(
            mgr, "start_subscription", new_callable=AsyncMock, side_effect=RuntimeError("fail")
        ):
            await mgr.auto_start_all_subscriptions()
        assert "fail" in mgr.last_error.get("test_auto", "")

    async def test_auto_start_calls_start_for_marked(self) -> None:
        mgr = SubscriptionManager()
        mgr.subscription_configs["auto_sub"] = {
            "query": "subscription { auto }",
            "resource": "unraid://auto",
            "description": "Auto sub",
            "auto_start": True,
        }

        with patch.object(mgr, "start_subscription", new_callable=AsyncMock) as mock_start:
            await mgr.auto_start_all_subscriptions()
        mock_start.assert_called_once_with("auto_sub", "subscription { auto }")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# SSL Context (via utils)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSSLContext:
    """``build_ws_ssl_context`` behaviour for each ``UNRAID_VERIFY_SSL`` mode."""

    # Patch target for the verification setting read by the SSL helper.
    _VERIFY_SETTING = "unraid_mcp.subscriptions.utils.UNRAID_VERIFY_SSL"

    def test_non_wss_returns_none(self) -> None:
        from unraid_mcp.subscriptions.utils import build_ws_ssl_context

        context = build_ws_ssl_context("ws://localhost:8080/graphql")
        assert context is None

    def test_wss_with_verify_true_returns_default_context(self) -> None:
        import ssl

        from unraid_mcp.subscriptions.utils import build_ws_ssl_context

        with patch(self._VERIFY_SETTING, True):
            context = build_ws_ssl_context("wss://test.local/graphql")

        assert isinstance(context, ssl.SSLContext)
        assert context.check_hostname is True

    def test_wss_with_verify_false_disables_verification(self) -> None:
        import ssl

        from unraid_mcp.subscriptions.utils import build_ws_ssl_context

        with patch(self._VERIFY_SETTING, False):
            context = build_ws_ssl_context("wss://test.local/graphql")

        assert isinstance(context, ssl.SSLContext)
        assert context.check_hostname is False
        assert context.verify_mode == ssl.CERT_NONE

    def test_wss_with_ca_bundle_path(self) -> None:
        from unraid_mcp.subscriptions.utils import build_ws_ssl_context

        bundle = "/path/to/ca-bundle.crt"
        with (
            patch(self._VERIFY_SETTING, bundle),
            patch("ssl.create_default_context") as create_ctx,
        ):
            build_ws_ssl_context("wss://test.local/graphql")

        create_ctx.assert_called_once_with(cafile="/path/to/ca-bundle.crt")
|