diff --git a/tests/test_snapshot.py b/tests/test_snapshot.py new file mode 100644 index 0000000..d5b5343 --- /dev/null +++ b/tests/test_snapshot.py @@ -0,0 +1,118 @@ +# tests/test_snapshot.py +"""Tests for subscribe_once() and subscribe_collect() snapshot helpers.""" + +from __future__ import annotations + +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +def _make_ws_message(sub_id: str, data: dict, proto: str = "graphql-transport-ws") -> str: + msg_type = "next" if proto == "graphql-transport-ws" else "data" + return json.dumps({"id": sub_id, "type": msg_type, "payload": {"data": data}}) + + +def _make_ws_recv_sequence(*messages: str): + """Build an async iterator that yields strings then hangs.""" + + async def _gen(): + for m in messages: + yield m + # hang — simulates no more messages + await asyncio.Event().wait() + + return _gen() + + +@pytest.fixture +def mock_ws(): + ws = MagicMock() + ws.subprotocol = "graphql-transport-ws" + ws.send = AsyncMock() + return ws + + +@pytest.mark.asyncio +async def test_subscribe_once_returns_first_event(mock_ws): + """subscribe_once returns data from the first matching event.""" + from unraid_mcp.subscriptions.snapshot import subscribe_once + + ack = json.dumps({"type": "connection_ack"}) + data_msg = _make_ws_message("snapshot-1", {"systemMetricsCpu": {"percentTotal": 42.0}}) + mock_ws.__aiter__ = lambda s: aiter([data_msg]) + mock_ws.recv = AsyncMock(return_value=ack) + + async def aiter(items): + for item in items: + yield item + + with patch("unraid_mcp.subscriptions.snapshot.websockets.connect") as mock_connect: + mock_connect.return_value.__aenter__ = AsyncMock(return_value=mock_ws) + mock_connect.return_value.__aexit__ = AsyncMock(return_value=False) + + result = await subscribe_once("subscription { systemMetricsCpu { percentTotal } }") + + assert result == {"systemMetricsCpu": {"percentTotal": 42.0}} + + +@pytest.mark.asyncio +async def test_subscribe_once_raises_on_graphql_error(mock_ws): + """subscribe_once raises ToolError when server returns GraphQL errors.""" + from unraid_mcp.core.exceptions import ToolError + from unraid_mcp.subscriptions.snapshot import subscribe_once + + ack = json.dumps({"type": "connection_ack"}) + error_msg = json.dumps( + { + "id": "snapshot-1", + "type": "next", + "payload": {"errors": [{"message": "Not authorized"}]}, + } + ) + + async def aiter(items): + for item in items: + yield item + + mock_ws.__aiter__ = lambda s: aiter([error_msg]) + mock_ws.recv = AsyncMock(return_value=ack) + + with patch("unraid_mcp.subscriptions.snapshot.websockets.connect") as mock_connect: + mock_connect.return_value.__aenter__ = AsyncMock(return_value=mock_ws) + mock_connect.return_value.__aexit__ = AsyncMock(return_value=False) + + with pytest.raises(ToolError, match="Not authorized"): + await subscribe_once("subscription { systemMetricsCpu { percentTotal } }") + + +@pytest.mark.asyncio +async def test_subscribe_collect_returns_multiple_events(mock_ws): + """subscribe_collect returns a list of events received within the window.""" + from unraid_mcp.subscriptions.snapshot import subscribe_collect + + ack = json.dumps({"type": "connection_ack"}) + msg1 = _make_ws_message("snapshot-1", {"notificationAdded": {"id": "1", "title": "A"}}) + msg2 = _make_ws_message("snapshot-1", {"notificationAdded": {"id": "2", "title": "B"}}) + + async def aiter(items): + for item in items: + yield item + await asyncio.sleep(10) # hang after messages + + mock_ws.__aiter__ = lambda s: aiter([msg1, msg2]) + mock_ws.recv = AsyncMock(return_value=ack) + + with patch("unraid_mcp.subscriptions.snapshot.websockets.connect") as mock_connect: + mock_connect.return_value.__aenter__ = AsyncMock(return_value=mock_ws) + mock_connect.return_value.__aexit__ = AsyncMock(return_value=False) + + result = await subscribe_collect( + "subscription { notificationAdded { id title } }", + collect_for=0.1, + ) + + assert len(result) == 2 + assert result[0]["notificationAdded"]["id"] == "1" diff --git a/unraid_mcp/subscriptions/snapshot.py b/unraid_mcp/subscriptions/snapshot.py new file mode 100644 index 0000000..839a9cb --- /dev/null +++ b/unraid_mcp/subscriptions/snapshot.py @@ -0,0 +1,165 @@ +"""One-shot GraphQL subscription helpers for MCP tool snapshot actions. + +`subscribe_once(query, variables, timeout)` — connect, subscribe, return the +first event's data, then disconnect. + +`subscribe_collect(query, variables, collect_for, timeout)` — connect, +subscribe, collect all events for `collect_for` seconds, return the list. + +Neither function maintains a persistent connection — they open and close a +WebSocket per call. This is intentional: MCP tools are request-response. +Use the SubscriptionManager for long-lived monitoring resources. +""" + +from __future__ import annotations + +import asyncio +import json +from typing import Any + +import websockets +from websockets.typing import Subprotocol + +from ..config.logging import logger +from ..config.settings import UNRAID_API_KEY +from ..core.exceptions import ToolError +from .utils import build_ws_ssl_context, build_ws_url + + +async def subscribe_once( + query: str, + variables: dict[str, Any] | None = None, + timeout: float = 10.0, +) -> dict[str, Any]: + """Open a WebSocket subscription, receive the first data event, close, return it. + + Raises ToolError on auth failure, GraphQL errors, or timeout. + """ + ws_url = build_ws_url() + ssl_context = build_ws_ssl_context(ws_url) + + async with websockets.connect( + ws_url, + subprotocols=[Subprotocol("graphql-transport-ws"), Subprotocol("graphql-ws")], + open_timeout=timeout, + ping_interval=20, + ping_timeout=10, + ssl=ssl_context, + ) as ws: + proto = ws.subprotocol or "graphql-transport-ws" + sub_id = "snapshot-1" + + # Handshake + init: dict[str, Any] = {"type": "connection_init"} + if UNRAID_API_KEY: + init["payload"] = {"headers": {"X-API-Key": UNRAID_API_KEY}} + await ws.send(json.dumps(init)) + + raw = await asyncio.wait_for(ws.recv(), timeout=timeout) + ack = json.loads(raw) + if ack.get("type") == "connection_error": + raise ToolError(f"Subscription auth failed: {ack.get('payload')}") + if ack.get("type") != "connection_ack": + raise ToolError(f"Unexpected handshake response: {ack.get('type')}") + + # Subscribe + start_type = "subscribe" if proto == "graphql-transport-ws" else "start" + await ws.send( + json.dumps( + { + "id": sub_id, + "type": start_type, + "payload": {"query": query, "variables": variables or {}}, + } + ) + ) + + # Await first matching data event + expected_type = "next" if proto == "graphql-transport-ws" else "data" + deadline = asyncio.get_event_loop().time() + timeout + + async for raw_msg in ws: + if asyncio.get_event_loop().time() >= deadline: + raise ToolError(f"Subscription timed out after {timeout:.0f}s") + + msg = json.loads(raw_msg) + if msg.get("type") == "ping": + await ws.send(json.dumps({"type": "pong"})) + continue + if msg.get("type") == expected_type and msg.get("id") == sub_id: + payload = msg.get("payload", {}) + if errors := payload.get("errors"): + msgs = "; ".join(e.get("message", str(e)) for e in errors) + raise ToolError(f"Subscription errors: {msgs}") + if data := payload.get("data"): + return data + elif msg.get("type") == "error" and msg.get("id") == sub_id: + raise ToolError(f"Subscription error: {msg.get('payload')}") + + raise ToolError("WebSocket closed before receiving subscription data") + + +async def subscribe_collect( + query: str, + variables: dict[str, Any] | None = None, + collect_for: float = 5.0, + timeout: float = 10.0, +) -> list[dict[str, Any]]: + """Open a subscription, collect events for `collect_for` seconds, close, return list. + + Returns an empty list if no events arrive within the window. + Always closes the connection after the window expires. + """ + ws_url = build_ws_url() + ssl_context = build_ws_ssl_context(ws_url) + events: list[dict[str, Any]] = [] + + async with websockets.connect( + ws_url, + subprotocols=[Subprotocol("graphql-transport-ws"), Subprotocol("graphql-ws")], + open_timeout=timeout, + ping_interval=20, + ping_timeout=10, + ssl=ssl_context, + ) as ws: + proto = ws.subprotocol or "graphql-transport-ws" + sub_id = "snapshot-1" + + init: dict[str, Any] = {"type": "connection_init"} + if UNRAID_API_KEY: + init["payload"] = {"headers": {"X-API-Key": UNRAID_API_KEY}} + await ws.send(json.dumps(init)) + + raw = await asyncio.wait_for(ws.recv(), timeout=timeout) + ack = json.loads(raw) + if ack.get("type") not in ("connection_ack",): + raise ToolError(f"Subscription handshake failed: {ack.get('type')}") + + start_type = "subscribe" if proto == "graphql-transport-ws" else "start" + await ws.send( + json.dumps( + { + "id": sub_id, + "type": start_type, + "payload": {"query": query, "variables": variables or {}}, + } + ) + ) + + expected_type = "next" if proto == "graphql-transport-ws" else "data" + collect_deadline = asyncio.get_event_loop().time() + collect_for + + async for raw_msg in ws: + if asyncio.get_event_loop().time() >= collect_deadline: + break + msg = json.loads(raw_msg) + if msg.get("type") == "ping": + await ws.send(json.dumps({"type": "pong"})) + continue + if msg.get("type") == expected_type and msg.get("id") == sub_id: + payload = msg.get("payload", {}) + if data := payload.get("data"): + events.append(data) + + logger.debug(f"[SNAPSHOT] Collected {len(events)} events in {collect_for}s window") + return events