mirror of
https://github.com/jmagar/unraid-mcp.git
synced 2026-03-23 12:39:24 -07:00
feat(subscriptions): add subscribe_once and subscribe_collect snapshot helpers
This commit is contained in:
118
tests/test_snapshot.py
Normal file
118
tests/test_snapshot.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
# tests/test_snapshot.py
|
||||||
|
"""Tests for subscribe_once() and subscribe_collect() snapshot helpers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def _make_ws_message(sub_id: str, data: dict, proto: str = "graphql-transport-ws") -> str:
|
||||||
|
msg_type = "next" if proto == "graphql-transport-ws" else "data"
|
||||||
|
return json.dumps({"id": sub_id, "type": msg_type, "payload": {"data": data}})
|
||||||
|
|
||||||
|
|
||||||
|
def _make_ws_recv_sequence(*messages: str):
|
||||||
|
"""Build an async iterator that yields strings then hangs."""
|
||||||
|
|
||||||
|
async def _gen():
|
||||||
|
for m in messages:
|
||||||
|
yield m
|
||||||
|
# hang — simulates no more messages
|
||||||
|
await asyncio.Event().wait()
|
||||||
|
|
||||||
|
return _gen()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_ws():
|
||||||
|
ws = MagicMock()
|
||||||
|
ws.subprotocol = "graphql-transport-ws"
|
||||||
|
ws.send = AsyncMock()
|
||||||
|
return ws
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_subscribe_once_returns_first_event(mock_ws):
|
||||||
|
"""subscribe_once returns data from the first matching event."""
|
||||||
|
from unraid_mcp.subscriptions.snapshot import subscribe_once
|
||||||
|
|
||||||
|
ack = json.dumps({"type": "connection_ack"})
|
||||||
|
data_msg = _make_ws_message("snapshot-1", {"systemMetricsCpu": {"percentTotal": 42.0}})
|
||||||
|
mock_ws.__aiter__ = lambda s: aiter([data_msg])
|
||||||
|
mock_ws.recv = AsyncMock(return_value=ack)
|
||||||
|
|
||||||
|
async def aiter(items):
|
||||||
|
for item in items:
|
||||||
|
yield item
|
||||||
|
|
||||||
|
with patch("unraid_mcp.subscriptions.snapshot.websockets.connect") as mock_connect:
|
||||||
|
mock_connect.return_value.__aenter__ = AsyncMock(return_value=mock_ws)
|
||||||
|
mock_connect.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
result = await subscribe_once("subscription { systemMetricsCpu { percentTotal } }")
|
||||||
|
|
||||||
|
assert result == {"systemMetricsCpu": {"percentTotal": 42.0}}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_subscribe_once_raises_on_graphql_error(mock_ws):
|
||||||
|
"""subscribe_once raises ToolError when server returns GraphQL errors."""
|
||||||
|
from unraid_mcp.core.exceptions import ToolError
|
||||||
|
from unraid_mcp.subscriptions.snapshot import subscribe_once
|
||||||
|
|
||||||
|
ack = json.dumps({"type": "connection_ack"})
|
||||||
|
error_msg = json.dumps(
|
||||||
|
{
|
||||||
|
"id": "snapshot-1",
|
||||||
|
"type": "next",
|
||||||
|
"payload": {"errors": [{"message": "Not authorized"}]},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def aiter(items):
|
||||||
|
for item in items:
|
||||||
|
yield item
|
||||||
|
|
||||||
|
mock_ws.__aiter__ = lambda s: aiter([error_msg])
|
||||||
|
mock_ws.recv = AsyncMock(return_value=ack)
|
||||||
|
|
||||||
|
with patch("unraid_mcp.subscriptions.snapshot.websockets.connect") as mock_connect:
|
||||||
|
mock_connect.return_value.__aenter__ = AsyncMock(return_value=mock_ws)
|
||||||
|
mock_connect.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
with pytest.raises(ToolError, match="Not authorized"):
|
||||||
|
await subscribe_once("subscription { systemMetricsCpu { percentTotal } }")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_subscribe_collect_returns_multiple_events(mock_ws):
|
||||||
|
"""subscribe_collect returns a list of events received within the window."""
|
||||||
|
from unraid_mcp.subscriptions.snapshot import subscribe_collect
|
||||||
|
|
||||||
|
ack = json.dumps({"type": "connection_ack"})
|
||||||
|
msg1 = _make_ws_message("snapshot-1", {"notificationAdded": {"id": "1", "title": "A"}})
|
||||||
|
msg2 = _make_ws_message("snapshot-1", {"notificationAdded": {"id": "2", "title": "B"}})
|
||||||
|
|
||||||
|
async def aiter(items):
|
||||||
|
for item in items:
|
||||||
|
yield item
|
||||||
|
await asyncio.sleep(10) # hang after messages
|
||||||
|
|
||||||
|
mock_ws.__aiter__ = lambda s: aiter([msg1, msg2])
|
||||||
|
mock_ws.recv = AsyncMock(return_value=ack)
|
||||||
|
|
||||||
|
with patch("unraid_mcp.subscriptions.snapshot.websockets.connect") as mock_connect:
|
||||||
|
mock_connect.return_value.__aenter__ = AsyncMock(return_value=mock_ws)
|
||||||
|
mock_connect.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
result = await subscribe_collect(
|
||||||
|
"subscription { notificationAdded { id title } }",
|
||||||
|
collect_for=0.1,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result) == 2
|
||||||
|
assert result[0]["notificationAdded"]["id"] == "1"
|
||||||
165
unraid_mcp/subscriptions/snapshot.py
Normal file
165
unraid_mcp/subscriptions/snapshot.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
"""One-shot GraphQL subscription helpers for MCP tool snapshot actions.
|
||||||
|
|
||||||
|
`subscribe_once(query, variables, timeout)` — connect, subscribe, return the
|
||||||
|
first event's data, then disconnect.
|
||||||
|
|
||||||
|
`subscribe_collect(query, variables, collect_for, timeout)` — connect,
|
||||||
|
subscribe, collect all events for `collect_for` seconds, return the list.
|
||||||
|
|
||||||
|
Neither function maintains a persistent connection — they open and close a
|
||||||
|
WebSocket per call. This is intentional: MCP tools are request-response.
|
||||||
|
Use the SubscriptionManager for long-lived monitoring resources.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import websockets
|
||||||
|
from websockets.typing import Subprotocol
|
||||||
|
|
||||||
|
from ..config.logging import logger
|
||||||
|
from ..config.settings import UNRAID_API_KEY
|
||||||
|
from ..core.exceptions import ToolError
|
||||||
|
from .utils import build_ws_ssl_context, build_ws_url
|
||||||
|
|
||||||
|
|
||||||
|
async def subscribe_once(
|
||||||
|
query: str,
|
||||||
|
variables: dict[str, Any] | None = None,
|
||||||
|
timeout: float = 10.0,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Open a WebSocket subscription, receive the first data event, close, return it.
|
||||||
|
|
||||||
|
Raises ToolError on auth failure, GraphQL errors, or timeout.
|
||||||
|
"""
|
||||||
|
ws_url = build_ws_url()
|
||||||
|
ssl_context = build_ws_ssl_context(ws_url)
|
||||||
|
|
||||||
|
async with websockets.connect(
|
||||||
|
ws_url,
|
||||||
|
subprotocols=[Subprotocol("graphql-transport-ws"), Subprotocol("graphql-ws")],
|
||||||
|
open_timeout=timeout,
|
||||||
|
ping_interval=20,
|
||||||
|
ping_timeout=10,
|
||||||
|
ssl=ssl_context,
|
||||||
|
) as ws:
|
||||||
|
proto = ws.subprotocol or "graphql-transport-ws"
|
||||||
|
sub_id = "snapshot-1"
|
||||||
|
|
||||||
|
# Handshake
|
||||||
|
init: dict[str, Any] = {"type": "connection_init"}
|
||||||
|
if UNRAID_API_KEY:
|
||||||
|
init["payload"] = {"headers": {"X-API-Key": UNRAID_API_KEY}}
|
||||||
|
await ws.send(json.dumps(init))
|
||||||
|
|
||||||
|
raw = await asyncio.wait_for(ws.recv(), timeout=timeout)
|
||||||
|
ack = json.loads(raw)
|
||||||
|
if ack.get("type") == "connection_error":
|
||||||
|
raise ToolError(f"Subscription auth failed: {ack.get('payload')}")
|
||||||
|
if ack.get("type") != "connection_ack":
|
||||||
|
raise ToolError(f"Unexpected handshake response: {ack.get('type')}")
|
||||||
|
|
||||||
|
# Subscribe
|
||||||
|
start_type = "subscribe" if proto == "graphql-transport-ws" else "start"
|
||||||
|
await ws.send(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"id": sub_id,
|
||||||
|
"type": start_type,
|
||||||
|
"payload": {"query": query, "variables": variables or {}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Await first matching data event
|
||||||
|
expected_type = "next" if proto == "graphql-transport-ws" else "data"
|
||||||
|
deadline = asyncio.get_event_loop().time() + timeout
|
||||||
|
|
||||||
|
async for raw_msg in ws:
|
||||||
|
if asyncio.get_event_loop().time() >= deadline:
|
||||||
|
raise ToolError(f"Subscription timed out after {timeout:.0f}s")
|
||||||
|
|
||||||
|
msg = json.loads(raw_msg)
|
||||||
|
if msg.get("type") == "ping":
|
||||||
|
await ws.send(json.dumps({"type": "pong"}))
|
||||||
|
continue
|
||||||
|
if msg.get("type") == expected_type and msg.get("id") == sub_id:
|
||||||
|
payload = msg.get("payload", {})
|
||||||
|
if errors := payload.get("errors"):
|
||||||
|
msgs = "; ".join(e.get("message", str(e)) for e in errors)
|
||||||
|
raise ToolError(f"Subscription errors: {msgs}")
|
||||||
|
if data := payload.get("data"):
|
||||||
|
return data
|
||||||
|
elif msg.get("type") == "error" and msg.get("id") == sub_id:
|
||||||
|
raise ToolError(f"Subscription error: {msg.get('payload')}")
|
||||||
|
|
||||||
|
raise ToolError("WebSocket closed before receiving subscription data")
|
||||||
|
|
||||||
|
|
||||||
|
async def subscribe_collect(
|
||||||
|
query: str,
|
||||||
|
variables: dict[str, Any] | None = None,
|
||||||
|
collect_for: float = 5.0,
|
||||||
|
timeout: float = 10.0,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Open a subscription, collect events for `collect_for` seconds, close, return list.
|
||||||
|
|
||||||
|
Returns an empty list if no events arrive within the window.
|
||||||
|
Always closes the connection after the window expires.
|
||||||
|
"""
|
||||||
|
ws_url = build_ws_url()
|
||||||
|
ssl_context = build_ws_ssl_context(ws_url)
|
||||||
|
events: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
async with websockets.connect(
|
||||||
|
ws_url,
|
||||||
|
subprotocols=[Subprotocol("graphql-transport-ws"), Subprotocol("graphql-ws")],
|
||||||
|
open_timeout=timeout,
|
||||||
|
ping_interval=20,
|
||||||
|
ping_timeout=10,
|
||||||
|
ssl=ssl_context,
|
||||||
|
) as ws:
|
||||||
|
proto = ws.subprotocol or "graphql-transport-ws"
|
||||||
|
sub_id = "snapshot-1"
|
||||||
|
|
||||||
|
init: dict[str, Any] = {"type": "connection_init"}
|
||||||
|
if UNRAID_API_KEY:
|
||||||
|
init["payload"] = {"headers": {"X-API-Key": UNRAID_API_KEY}}
|
||||||
|
await ws.send(json.dumps(init))
|
||||||
|
|
||||||
|
raw = await asyncio.wait_for(ws.recv(), timeout=timeout)
|
||||||
|
ack = json.loads(raw)
|
||||||
|
if ack.get("type") not in ("connection_ack",):
|
||||||
|
raise ToolError(f"Subscription handshake failed: {ack.get('type')}")
|
||||||
|
|
||||||
|
start_type = "subscribe" if proto == "graphql-transport-ws" else "start"
|
||||||
|
await ws.send(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"id": sub_id,
|
||||||
|
"type": start_type,
|
||||||
|
"payload": {"query": query, "variables": variables or {}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_type = "next" if proto == "graphql-transport-ws" else "data"
|
||||||
|
collect_deadline = asyncio.get_event_loop().time() + collect_for
|
||||||
|
|
||||||
|
async for raw_msg in ws:
|
||||||
|
if asyncio.get_event_loop().time() >= collect_deadline:
|
||||||
|
break
|
||||||
|
msg = json.loads(raw_msg)
|
||||||
|
if msg.get("type") == "ping":
|
||||||
|
await ws.send(json.dumps({"type": "pong"}))
|
||||||
|
continue
|
||||||
|
if msg.get("type") == expected_type and msg.get("id") == sub_id:
|
||||||
|
payload = msg.get("payload", {})
|
||||||
|
if data := payload.get("data"):
|
||||||
|
events.append(data)
|
||||||
|
|
||||||
|
logger.debug(f"[SNAPSHOT] Collected {len(events)} events in {collect_for}s window")
|
||||||
|
return events
|
||||||
Reference in New Issue
Block a user