Files
unraid-mcp/unraid_mcp/subscriptions/manager.py
Jacob Magar 1751bc2984 fix: apply all PR review agent findings (silent failures, type safety, test gaps)
Addresses issues found by 4 parallel review agents (code-reviewer,
silent-failure-hunter, type-design-analyzer, pr-test-analyzer).

Source fixes:
- core/utils.py: add public safe_display_url() (moved from tools/health.py)
- core/client.py: rename _redact_sensitive → redact_sensitive (public API)
- core/types.py: add SubscriptionData.__post_init__ for tz-aware datetime
  enforcement; remove 6 unused type aliases (SystemHealth, APIResponse, etc.)
- subscriptions/manager.py: add exc_info=True to both `except Exception` blocks;
  add an `except ValueError` handler that breaks out of the retry loop on
  configuration errors instead of retrying; import redact_sensitive under its
  new public name
- subscriptions/resources.py: re-raise in autostart_subscriptions() so that
  ensure_subscriptions_started() doesn't permanently set _subscriptions_started
  after a failed start
- subscriptions/diagnostics.py: re-raise ToolError before the broad
  `except Exception` handler; use safe_display_url() instead of a raw URL slice
- tools/health.py: move _safe_display_url to core/utils; add exc_info=True;
  raise ToolError (not return dict) on ImportError
- tools/info.py: use get_args(INFO_ACTIONS) instead of INFO_ACTIONS.__args__
- tools/{array,docker,keys,notifications,rclone,storage,virtualization}.py:
  add Literal-vs-ALL_ACTIONS sync check at import time

Test fixes:
- test_health.py: import safe_display_url from core.utils; update
  test_diagnose_import_error_internal to expect ToolError (not error dict)
- test_storage.py: add 3 safe_get tests for zero/False/empty-string values
- test_subscription_manager.py: add TestCapLogContentSingleMassiveLine (2 tests)
- test_client.py: rename _redact_sensitive → redact_sensitive; add tests for
  new sensitive keys and is_cacheable explicit-keyword form
2026-02-19 02:23:04 -05:00

593 lines
28 KiB
Python

"""WebSocket subscription manager for real-time Unraid data.
This module manages GraphQL subscriptions over WebSocket connections,
providing real-time data streaming for MCP resources with comprehensive
error handling, reconnection logic, and authentication.
"""
import asyncio
import json
import os
import time
from datetime import UTC, datetime
from typing import Any
import websockets
from websockets.typing import Subprotocol
from ..config.logging import logger
from ..config.settings import UNRAID_API_KEY
from ..core.client import redact_sensitive
from ..core.types import SubscriptionData
from .utils import build_ws_ssl_context, build_ws_url
# Resource data size limits to prevent unbounded memory growth
_MAX_RESOURCE_DATA_BYTES = 1_048_576 # 1MB
_MAX_RESOURCE_DATA_LINES = 5_000
# Minimum stable connection duration (seconds) before resetting reconnect counter
_STABLE_CONNECTION_SECONDS = 30
def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]:
"""Cap log content in subscription data to prevent unbounded memory growth.
Returns a new dict — does NOT mutate the input. If any nested 'content'
field (from log subscriptions) exceeds the byte limit, truncate it to the
most recent _MAX_RESOURCE_DATA_LINES lines.
The final content is guaranteed to be <= _MAX_RESOURCE_DATA_BYTES.
"""
result: dict[str, Any] = {}
for key, value in data.items():
if isinstance(value, dict):
result[key] = _cap_log_content(value)
elif (
key == "content"
and isinstance(value, str)
and len(value.encode("utf-8", errors="replace")) > _MAX_RESOURCE_DATA_BYTES
):
lines = value.splitlines()
original_line_count = len(lines)
# Keep most recent lines first.
if len(lines) > _MAX_RESOURCE_DATA_LINES:
lines = lines[-_MAX_RESOURCE_DATA_LINES:]
# Enforce byte cap while preserving whole-line boundaries where possible.
truncated = "\n".join(lines)
truncated_bytes = truncated.encode("utf-8", errors="replace")
while len(lines) > 1 and len(truncated_bytes) > _MAX_RESOURCE_DATA_BYTES:
lines = lines[1:]
truncated = "\n".join(lines)
truncated_bytes = truncated.encode("utf-8", errors="replace")
# Last resort: if a single line still exceeds cap, hard-cap bytes.
if len(truncated_bytes) > _MAX_RESOURCE_DATA_BYTES:
truncated = truncated_bytes[-_MAX_RESOURCE_DATA_BYTES :].decode(
"utf-8", errors="ignore"
)
logger.warning(
f"[RESOURCE] Capped log content from {original_line_count} to "
f"{len(lines)} lines ({len(value)} -> {len(truncated)} chars)"
)
result[key] = truncated
else:
result[key] = value
return result
class SubscriptionManager:
    """Manages GraphQL subscriptions and converts them to MCP resources.

    Each subscription runs as its own asyncio task (``_subscription_loop``) that
    connects over WebSocket, performs the GraphQL-WS handshake, subscribes, and
    stores every data update in ``resource_data``. Per-subscription bookkeeping
    (task, connection state, reconnect attempts, last error) lives in dicts keyed
    by subscription name. ``subscription_lock`` guards changes to
    ``active_subscriptions`` and reads/writes of ``resource_data``.
    """

    def __init__(self) -> None:
        # Running subscription tasks, keyed by subscription name.
        self.active_subscriptions: dict[str, asyncio.Task[None]] = {}
        # Latest data received for each subscription, keyed by subscription name.
        self.resource_data: dict[str, SubscriptionData] = {}
        self.subscription_lock = asyncio.Lock()

        # Configuration (read from the environment once, at construction time)
        self.auto_start_enabled = (
            os.getenv("UNRAID_AUTO_START_SUBSCRIPTIONS", "true").lower() == "true"
        )
        self.reconnect_attempts: dict[str, int] = {}
        self.max_reconnect_attempts = int(os.getenv("UNRAID_MAX_RECONNECT_ATTEMPTS", "10"))
        self.connection_states: dict[str, str] = {}  # Track connection state per subscription
        self.last_error: dict[str, str] = {}  # Track last error per subscription
        self._connection_start_times: dict[str, float] = {}  # Track when connections started

        # Define subscription configurations
        self.subscription_configs: dict[str, dict[str, Any]] = {
            "logFileSubscription": {
                "query": """
                subscription LogFileSubscription($path: String!) {
                    logFile(path: $path) {
                        path
                        content
                        totalLines
                    }
                }
                """,
                "resource": "unraid://logs/stream",
                "description": "Real-time log file streaming",
                "auto_start": False,  # Started manually with path parameter
            }
        }

        logger.info(
            f"[SUBSCRIPTION_MANAGER] Initialized with auto_start={self.auto_start_enabled}, max_reconnects={self.max_reconnect_attempts}"
        )
        logger.debug(
            f"[SUBSCRIPTION_MANAGER] Available subscriptions: {list(self.subscription_configs.keys())}"
        )

    async def auto_start_all_subscriptions(self) -> None:
        """Auto-start all subscriptions marked for auto-start.

        Does nothing when ``UNRAID_AUTO_START_SUBSCRIPTIONS`` disabled auto-start.
        A failure to start one subscription is logged and recorded in
        ``last_error``; it does not prevent the others from starting.
        """
        if not self.auto_start_enabled:
            logger.info("[SUBSCRIPTION_MANAGER] Auto-start disabled")
            return

        logger.info("[SUBSCRIPTION_MANAGER] Starting auto-start process...")
        auto_start_count = 0

        for subscription_name, config in self.subscription_configs.items():
            if config.get("auto_start", False):
                try:
                    logger.info(
                        f"[SUBSCRIPTION_MANAGER] Auto-starting subscription: {subscription_name}"
                    )
                    await self.start_subscription(subscription_name, str(config["query"]))
                    auto_start_count += 1
                except Exception as e:
                    logger.error(
                        f"[SUBSCRIPTION_MANAGER] Failed to auto-start {subscription_name}: {e}",
                        exc_info=True,
                    )
                    self.last_error[subscription_name] = str(e)

        logger.info(
            f"[SUBSCRIPTION_MANAGER] Auto-start completed. Started {auto_start_count} subscriptions"
        )

    async def start_subscription(
        self, subscription_name: str, query: str, variables: dict[str, Any] | None = None
    ) -> None:
        """Start a GraphQL subscription and maintain it as a resource.

        The duplicate check, the state reset and the task creation all happen while
        holding ``subscription_lock``. Otherwise two concurrent callers could both
        pass the check and the first task would be orphaned. A registered task that
        has already finished (e.g. it died from an unexpected error) is treated as
        inactive and replaced.

        Args:
            subscription_name: Key for all per-subscription bookkeeping; also used
                as the GraphQL operation id.
            query: GraphQL subscription document.
            variables: Optional GraphQL variables (``None`` is sent as ``{}``).

        Raises:
            Exception: Any error from creating the task is recorded in
                ``connection_states``/``last_error`` and re-raised.
        """
        logger.info(f"[SUBSCRIPTION:{subscription_name}] Starting subscription...")

        async with self.subscription_lock:
            existing = self.active_subscriptions.get(subscription_name)
            if existing is not None and not existing.done():
                logger.warning(
                    f"[SUBSCRIPTION:{subscription_name}] Subscription already active, skipping"
                )
                return
            if existing is not None:
                # The loop removes itself on a normal exit, so a finished task still
                # registered here ended abnormally. Retrieving its exception also
                # avoids asyncio's "exception was never retrieved" warning.
                previous_error = None if existing.cancelled() else existing.exception()
                logger.warning(
                    f"[SUBSCRIPTION:{subscription_name}] Previous subscription task had "
                    f"already finished ({previous_error!r}); restarting"
                )

            # Reset connection tracking
            self.reconnect_attempts[subscription_name] = 0
            self.connection_states[subscription_name] = "starting"
            self._connection_start_times.pop(subscription_name, None)

            try:
                task = asyncio.create_task(
                    self._subscription_loop(subscription_name, query, variables or {})
                )
                self.active_subscriptions[subscription_name] = task
                logger.info(
                    f"[SUBSCRIPTION:{subscription_name}] Subscription task created and started"
                )
                self.connection_states[subscription_name] = "active"
            except Exception as e:
                logger.error(
                    f"[SUBSCRIPTION:{subscription_name}] Failed to start subscription task: {e}"
                )
                self.connection_states[subscription_name] = "failed"
                self.last_error[subscription_name] = str(e)
                raise

    async def stop_subscription(self, subscription_name: str) -> None:
        """Stop a specific subscription.

        Cancels the subscription's task and waits for it to finish while holding
        ``subscription_lock``, then removes it and marks it ``"stopped"``. If the
        task ended with an exception instead of a cancellation, the exception is
        logged rather than propagated, so the stale entry is still cleaned up.
        """
        logger.info(f"[SUBSCRIPTION:{subscription_name}] Stopping subscription...")

        async with self.subscription_lock:
            task = self.active_subscriptions.get(subscription_name)
            if task is None:
                logger.warning(f"[SUBSCRIPTION:{subscription_name}] No active subscription to stop")
                return

            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                logger.debug(f"[SUBSCRIPTION:{subscription_name}] Task cancelled successfully")
            except Exception as e:
                # The task died on its own before the cancellation took effect.
                logger.warning(
                    f"[SUBSCRIPTION:{subscription_name}] Task ended with an error: {e}",
                    exc_info=True,
                )
            self.active_subscriptions.pop(subscription_name, None)
            self.connection_states[subscription_name] = "stopped"
            self._connection_start_times.pop(subscription_name, None)
            logger.info(f"[SUBSCRIPTION:{subscription_name}] Subscription stopped")

    @staticmethod
    def _preview_message(raw: str | bytes) -> str:
        """Return the first 200 characters/bytes of a raw WebSocket frame as text for logs."""
        if isinstance(raw, str):
            return raw[:200]
        return raw[:200].decode("utf-8", errors="replace")

    async def _subscription_loop(
        self, subscription_name: str, query: str, variables: dict[str, Any] | None
    ) -> None:
        """Main loop for maintaining a GraphQL subscription with comprehensive logging.

        Connection sequence:
            * Open a WebSocket offering ``graphql-transport-ws`` and ``graphql-ws``.
            * Send ``connection_init``, with the API key as an ``X-API-Key`` header
              when one is configured.
            * Wait up to 30s for the init response.
            * Subscribe, then store every data message for this operation id in
              ``resource_data``.

        Retry policy:
            * Every iteration counts as an attempt. Once attempts exceed
              ``max_reconnect_attempts``, the loop stops (``max_retries_exceeded``).
            * A connection that stayed up for at least
              ``_STABLE_CONNECTION_SECONDS`` resets the attempt counter and backoff.
            * The delay is multiplied by 1.5 before each retry, capped at 300 seconds.
            * These stop the loop without retrying: an invalid JSON init response,
              ``connection_error`` (auth failure), an invalid URI, and a
              ``ValueError`` (configuration error).
            * A server ``error`` or ``complete`` message ends the operation, so the
              loop closes the socket and reconnects (subject to the limits above).

        On exit the subscription removes itself from ``active_subscriptions`` so that
        ``start_subscription`` can start it again.
        """
        retry_delay: int | float = 5
        max_retry_delay = 300  # 5 minutes max

        while True:
            attempt = self.reconnect_attempts.get(subscription_name, 0) + 1
            self.reconnect_attempts[subscription_name] = attempt
            logger.info(
                f"[WEBSOCKET:{subscription_name}] Connection attempt #{attempt} (max: {self.max_reconnect_attempts})"
            )

            if attempt > self.max_reconnect_attempts:
                logger.error(
                    f"[WEBSOCKET:{subscription_name}] Max reconnection attempts ({self.max_reconnect_attempts}) exceeded, stopping"
                )
                self.connection_states[subscription_name] = "max_retries_exceeded"
                break

            try:
                ws_url = build_ws_url()
                logger.debug(f"[WEBSOCKET:{subscription_name}] Connecting to: {ws_url}")
                logger.debug(
                    f"[WEBSOCKET:{subscription_name}] API Key present: {'Yes' if UNRAID_API_KEY else 'No'}"
                )

                ssl_context = build_ws_ssl_context(ws_url)

                # Connection with timeout
                connect_timeout = 10
                logger.debug(
                    f"[WEBSOCKET:{subscription_name}] Connection timeout: {connect_timeout}s"
                )

                async with websockets.connect(
                    ws_url,
                    subprotocols=[Subprotocol("graphql-transport-ws"), Subprotocol("graphql-ws")],
                    open_timeout=connect_timeout,
                    ping_interval=20,
                    ping_timeout=10,
                    close_timeout=10,
                    ssl=ssl_context,
                ) as websocket:
                    selected_proto = websocket.subprotocol or "none"
                    logger.info(
                        f"[WEBSOCKET:{subscription_name}] Connected! Protocol: {selected_proto}"
                    )
                    self.connection_states[subscription_name] = "connected"

                    # Track connection start time — only reset retry counter
                    # after the connection proves stable (>30s connected)
                    self._connection_start_times[subscription_name] = time.monotonic()

                    # Initialize GraphQL-WS protocol
                    logger.debug(
                        f"[PROTOCOL:{subscription_name}] Initializing GraphQL-WS protocol..."
                    )
                    init_type = "connection_init"
                    init_payload: dict[str, Any] = {"type": init_type}

                    if UNRAID_API_KEY:
                        logger.debug(f"[AUTH:{subscription_name}] Adding authentication payload")
                        # Use standard X-API-Key header format (matching HTTP client)
                        auth_payload = {"headers": {"X-API-Key": UNRAID_API_KEY}}
                        init_payload["payload"] = auth_payload
                    else:
                        logger.warning(
                            f"[AUTH:{subscription_name}] No API key available for authentication"
                        )

                    logger.debug(f"[PROTOCOL:{subscription_name}] Sending connection_init message")
                    await websocket.send(json.dumps(init_payload))

                    # Wait for connection acknowledgment
                    logger.debug(f"[PROTOCOL:{subscription_name}] Waiting for connection_ack...")
                    init_raw = await asyncio.wait_for(websocket.recv(), timeout=30)

                    try:
                        init_data = json.loads(init_raw)
                        logger.debug(
                            f"[PROTOCOL:{subscription_name}] Received init response: {init_data.get('type')}"
                        )
                    except json.JSONDecodeError as e:
                        init_preview = self._preview_message(init_raw)
                        logger.error(
                            f"[PROTOCOL:{subscription_name}] Failed to decode init response: {init_preview}..."
                        )
                        self.last_error[subscription_name] = f"Invalid JSON in init response: {e}"
                        break

                    # Handle connection acknowledgment
                    if init_data.get("type") == "connection_ack":
                        logger.info(
                            f"[PROTOCOL:{subscription_name}] Connection acknowledged successfully"
                        )
                        self.connection_states[subscription_name] = "authenticated"
                    elif init_data.get("type") == "connection_error":
                        error_payload = init_data.get("payload", {})
                        logger.error(
                            f"[AUTH:{subscription_name}] Authentication failed: {error_payload}"
                        )
                        self.last_error[subscription_name] = (
                            f"Authentication error: {error_payload}"
                        )
                        self.connection_states[subscription_name] = "auth_failed"
                        break
                    else:
                        logger.warning(
                            f"[PROTOCOL:{subscription_name}] Unexpected init response: {init_data}"
                        )
                        # Continue anyway - some servers send other messages first

                    # Start the subscription
                    logger.debug(
                        f"[SUBSCRIPTION:{subscription_name}] Starting GraphQL subscription..."
                    )
                    start_type = (
                        "subscribe" if selected_proto == "graphql-transport-ws" else "start"
                    )
                    subscription_message = {
                        "id": subscription_name,
                        "type": start_type,
                        "payload": {"query": query, "variables": variables},
                    }

                    logger.debug(
                        f"[SUBSCRIPTION:{subscription_name}] Subscription message type: {start_type}"
                    )
                    logger.debug(f"[SUBSCRIPTION:{subscription_name}] Query: {query[:100]}...")
                    logger.debug(
                        f"[SUBSCRIPTION:{subscription_name}] Variables: {redact_sensitive(variables)}"
                    )

                    await websocket.send(json.dumps(subscription_message))
                    logger.info(
                        f"[SUBSCRIPTION:{subscription_name}] Subscription started successfully"
                    )
                    self.connection_states[subscription_name] = "subscribed"

                    # Listen for subscription data
                    message_count = 0
                    # Data messages are "next" in graphql-transport-ws, "data" otherwise.
                    expected_data_type = (
                        "next" if selected_proto == "graphql-transport-ws" else "data"
                    )
                    async for message in websocket:
                        try:
                            data = json.loads(message)
                            message_count += 1
                            message_type = data.get("type", "unknown")

                            logger.debug(
                                f"[DATA:{subscription_name}] Message #{message_count}: {message_type}"
                            )

                            # Handle different message types
                            if (
                                message_type == expected_data_type
                                and data.get("id") == subscription_name
                            ):
                                payload = data.get("payload", {})

                                if payload.get("data"):
                                    logger.info(
                                        f"[DATA:{subscription_name}] Received subscription data update"
                                    )
                                    capped_data = (
                                        _cap_log_content(payload["data"])
                                        if isinstance(payload["data"], dict)
                                        else payload["data"]
                                    )
                                    new_entry = SubscriptionData(
                                        data=capped_data,
                                        last_updated=datetime.now(UTC),
                                        subscription_type=subscription_name,
                                    )
                                    async with self.subscription_lock:
                                        self.resource_data[subscription_name] = new_entry
                                    logger.debug(
                                        f"[RESOURCE:{subscription_name}] Resource data updated successfully"
                                    )
                                elif payload.get("errors"):
                                    logger.error(
                                        f"[DATA:{subscription_name}] GraphQL errors in response: {payload['errors']}"
                                    )
                                    self.last_error[subscription_name] = (
                                        f"GraphQL errors: {payload['errors']}"
                                    )
                                else:
                                    logger.warning(
                                        f"[DATA:{subscription_name}] Empty or invalid data payload: {payload}"
                                    )

                            elif message_type == "ping":
                                logger.debug(
                                    f"[PROTOCOL:{subscription_name}] Received ping, sending pong"
                                )
                                await websocket.send(json.dumps({"type": "pong"}))

                            elif message_type == "error":
                                error_payload = data.get("payload", {})
                                logger.error(
                                    f"[SUBSCRIPTION:{subscription_name}] Subscription error: {error_payload}"
                                )
                                self.last_error[subscription_name] = (
                                    f"Subscription error: {error_payload}"
                                )
                                self.connection_states[subscription_name] = "error"
                                # In both GraphQL-WS protocols an error message ends the
                                # operation: no further data will arrive for it. Leave
                                # the socket so the outer loop reconnects (bounded by
                                # max_reconnect_attempts) instead of idling forever.
                                break

                            elif message_type == "complete":
                                logger.info(
                                    f"[SUBSCRIPTION:{subscription_name}] Subscription completed by server"
                                )
                                self.connection_states[subscription_name] = "completed"
                                break

                            elif message_type in ("ka", "pong"):
                                logger.debug(
                                    f"[PROTOCOL:{subscription_name}] Keepalive message: {message_type}"
                                )

                            else:
                                logger.debug(
                                    f"[PROTOCOL:{subscription_name}] Unhandled message type: {message_type}"
                                )

                        except websockets.exceptions.ConnectionClosed:
                            # e.g. raised by websocket.send() for the pong reply when the
                            # peer has gone away. Let the outer handler record the
                            # disconnect and reconnect instead of treating it as a
                            # per-message processing error.
                            raise
                        except json.JSONDecodeError as e:
                            msg_preview = self._preview_message(message)
                            logger.error(
                                f"[PROTOCOL:{subscription_name}] Failed to decode message: {msg_preview}..."
                            )
                            logger.error(f"[PROTOCOL:{subscription_name}] JSON decode error: {e}")
                        except Exception as e:
                            logger.error(
                                f"[DATA:{subscription_name}] Error processing message: {e}",
                                exc_info=True,
                            )
                            msg_preview = self._preview_message(message)
                            logger.debug(
                                f"[DATA:{subscription_name}] Raw message: {msg_preview}..."
                            )

            except TimeoutError:
                error_msg = "Connection or authentication timeout"
                logger.error(f"[WEBSOCKET:{subscription_name}] {error_msg}")
                self.last_error[subscription_name] = error_msg
                self.connection_states[subscription_name] = "timeout"

            except websockets.exceptions.ConnectionClosed as e:
                error_msg = f"WebSocket connection closed: {e}"
                logger.warning(f"[WEBSOCKET:{subscription_name}] {error_msg}")
                self.last_error[subscription_name] = error_msg
                self.connection_states[subscription_name] = "disconnected"

            except websockets.exceptions.InvalidURI as e:
                error_msg = f"Invalid WebSocket URI: {e}"
                logger.error(f"[WEBSOCKET:{subscription_name}] {error_msg}")
                self.last_error[subscription_name] = error_msg
                self.connection_states[subscription_name] = "invalid_uri"
                break  # Don't retry on invalid URI

            except ValueError as e:
                # Non-retryable configuration error (e.g. UNRAID_API_URL not set)
                error_msg = f"Configuration error: {e}"
                logger.error(f"[WEBSOCKET:{subscription_name}] {error_msg}")
                self.last_error[subscription_name] = error_msg
                self.connection_states[subscription_name] = "error"
                break  # Don't retry on configuration errors

            except Exception as e:
                error_msg = f"Unexpected error: {e}"
                logger.error(f"[WEBSOCKET:{subscription_name}] {error_msg}", exc_info=True)
                self.last_error[subscription_name] = error_msg
                self.connection_states[subscription_name] = "error"

            # Check if connection was stable before deciding on retry behavior
            start_time = self._connection_start_times.pop(subscription_name, None)
            if start_time is not None:
                connected_duration = time.monotonic() - start_time
                if connected_duration >= _STABLE_CONNECTION_SECONDS:
                    # Connection was stable — reset retry counter and backoff
                    logger.info(
                        f"[WEBSOCKET:{subscription_name}] Connection was stable "
                        f"({connected_duration:.0f}s >= {_STABLE_CONNECTION_SECONDS}s), "
                        f"resetting retry counter"
                    )
                    self.reconnect_attempts[subscription_name] = 0
                    retry_delay = 5
                else:
                    logger.warning(
                        f"[WEBSOCKET:{subscription_name}] Connection was unstable "
                        f"({connected_duration:.0f}s < {_STABLE_CONNECTION_SECONDS}s), "
                        f"keeping retry counter at {self.reconnect_attempts.get(subscription_name, 0)}"
                    )

            # Calculate backoff delay
            retry_delay = min(retry_delay * 1.5, max_retry_delay)
            logger.info(
                f"[WEBSOCKET:{subscription_name}] Reconnecting in {retry_delay:.1f} seconds..."
            )
            self.connection_states[subscription_name] = "reconnecting"
            await asyncio.sleep(retry_delay)

        # The while loop exited (via break or max_retries exceeded).
        # Remove from active_subscriptions so start_subscription() can restart it.
        async with self.subscription_lock:
            self.active_subscriptions.pop(subscription_name, None)
            # Breaks taken while connected skip the stability check above, so
            # drop any leftover connection start timestamp here.
            self._connection_start_times.pop(subscription_name, None)
        logger.info(
            f"[SUBSCRIPTION:{subscription_name}] Subscription loop ended — "
            f"removed from active_subscriptions. Final state: "
            f"{self.connection_states.get(subscription_name, 'unknown')}"
        )

    async def get_resource_data(self, resource_name: str) -> dict[str, Any] | None:
        """Get current resource data with enhanced logging.

        ``resource_name`` is the subscription name that ``resource_data`` is
        keyed by. Returns the stored payload (not a copy), or ``None`` if no
        data has been received yet.
        """
        logger.debug(f"[RESOURCE:{resource_name}] Resource data requested")

        async with self.subscription_lock:
            if resource_name in self.resource_data:
                data = self.resource_data[resource_name]
                age_seconds = (datetime.now(UTC) - data.last_updated).total_seconds()
                logger.debug(f"[RESOURCE:{resource_name}] Data found, age: {age_seconds:.1f}s")
                return data.data
            logger.debug(f"[RESOURCE:{resource_name}] No data available")
            return None

    def list_active_subscriptions(self) -> list[str]:
        """List all active subscriptions (names currently registered with a task)."""
        active = list(self.active_subscriptions.keys())
        logger.debug(f"[SUBSCRIPTION_MANAGER] Active subscriptions: {active}")
        return active

    async def get_subscription_status(self) -> dict[str, dict[str, Any]]:
        """Get detailed status of all subscriptions for diagnostics.

        Returns one entry per configured subscription, with ``config``
        (resource, description, auto_start), ``runtime`` (active flag,
        connection state, reconnect attempts, last error) and ``data``
        (availability plus last-update timestamp and age when present) sections.
        """
        status: dict[str, dict[str, Any]] = {}

        async with self.subscription_lock:
            for sub_name, config in self.subscription_configs.items():
                sub_status: dict[str, Any] = {
                    "config": {
                        "resource": config["resource"],
                        "description": config["description"],
                        "auto_start": config.get("auto_start", False),
                    },
                    "runtime": {
                        "active": sub_name in self.active_subscriptions,
                        "connection_state": self.connection_states.get(sub_name, "not_started"),
                        "reconnect_attempts": self.reconnect_attempts.get(sub_name, 0),
                        "last_error": self.last_error.get(sub_name, None),
                    },
                }

                # Add data info if available
                if sub_name in self.resource_data:
                    data_info = self.resource_data[sub_name]
                    age_seconds = (datetime.now(UTC) - data_info.last_updated).total_seconds()
                    sub_status["data"] = {
                        "available": True,
                        "last_updated": data_info.last_updated.isoformat(),
                        "age_seconds": age_seconds,
                    }
                else:
                    sub_status["data"] = {"available": False}

                status[sub_name] = sub_status

        logger.debug(f"[SUBSCRIPTION_MANAGER] Generated status for {len(status)} subscriptions")
        return status
# Global subscription manager instance, created once when this module is imported.
subscription_manager: SubscriptionManager = SubscriptionManager()