Files
unraid-mcp/unraid_mcp/subscriptions/manager.py
Jacob Magar 37e9424a5c fix: address 54 MEDIUM/LOW priority PR review issues
Comprehensive fixes across Python code, shell scripts, and documentation
addressing all remaining MEDIUM and LOW priority review comments.

Python Code Fixes (27 fixes):
- tools/info.py: Simplified dispatch with lookup tables, defensive guards,
  CPU fallback formatting, !s conversion flags, module-level sync assertion
- tools/docker.py: Case-insensitive container ID regex, keyword-only confirm,
  module-level ALL_ACTIONS constant
- tools/virtualization.py: Normalized single-VM dict responses, unified
  list/details queries
- core/client.py: Fixed HTTP client singleton race condition, compound key
  substring matching for sensitive data redaction
- subscriptions/: Extracted SSL context creation to shared helper in utils.py,
  replaced deprecated ssl._create_unverified_context API
- tools/array.py: Renamed parity_history to parity_status, hoisted ALL_ACTIONS
- tools/storage.py: Fixed dict(None) risks, temperature 0 falsiness bug
- tools/notifications.py, keys.py, rclone.py: Fixed dict(None) TypeError risks
- tests/: Fixed generator type annotations, added coverage for compound keys

Shell Script Fixes (13 fixes):
- dashboard.sh: Dynamic server discovery, conditional debug output, null-safe
  jq, notification count guard order, removed unused variables
- unraid-query.sh: Proper JSON escaping via jq, --ignore-errors and --insecure
  CLI flags, TLS verification now on by default
- validate-marketplace.sh: Removed unused YELLOW variable, defensive jq,
  simplified repository URL output

Documentation Fixes (24+ fixes):
- Version consistency: Updated all references to v0.2.0 across pyproject.toml,
  plugin.json, marketplace.json, MARKETPLACE.md, __init__.py, README files
- Tool count updates: Changed all "26 tools" references to "10 tools, 90 actions"
- Markdown lint: Fixed MD022, MD031, MD047 issues across multiple files
- Research docs: Fixed auth headers, removed web artifacts, corrected stale info
- Skills docs: Fixed query examples, endpoint counts, env var references

All 227 tests pass, ruff and ty checks clean.
2026-02-15 17:09:31 -05:00

396 lines
20 KiB
Python

"""WebSocket subscription manager for real-time Unraid data.
This module manages GraphQL subscriptions over WebSocket connections,
providing real-time data streaming for MCP resources with comprehensive
error handling, reconnection logic, and authentication.
"""
import asyncio
import json
import os
from datetime import datetime
from typing import Any
import websockets
from websockets.typing import Subprotocol
from ..config.logging import logger
from ..config.settings import UNRAID_API_KEY, UNRAID_API_URL
from ..core.types import SubscriptionData
from .utils import build_ws_ssl_context
class SubscriptionManager:
"""Manages GraphQL subscriptions and converts them to MCP resources."""
def __init__(self) -> None:
self.active_subscriptions: dict[str, asyncio.Task[None]] = {}
self.resource_data: dict[str, SubscriptionData] = {}
self.websocket: websockets.WebSocketServerProtocol | None = None
self.subscription_lock = asyncio.Lock()
# Configuration
self.auto_start_enabled = os.getenv("UNRAID_AUTO_START_SUBSCRIPTIONS", "true").lower() == "true"
self.reconnect_attempts: dict[str, int] = {}
self.max_reconnect_attempts = int(os.getenv("UNRAID_MAX_RECONNECT_ATTEMPTS", "10"))
self.connection_states: dict[str, str] = {} # Track connection state per subscription
self.last_error: dict[str, str] = {} # Track last error per subscription
# Define subscription configurations
self.subscription_configs = {
"logFileSubscription": {
"query": """
subscription LogFileSubscription($path: String!) {
logFile(path: $path) {
path
content
totalLines
}
}
""",
"resource": "unraid://logs/stream",
"description": "Real-time log file streaming",
"auto_start": False # Started manually with path parameter
}
}
logger.info(f"[SUBSCRIPTION_MANAGER] Initialized with auto_start={self.auto_start_enabled}, max_reconnects={self.max_reconnect_attempts}")
logger.debug(f"[SUBSCRIPTION_MANAGER] Available subscriptions: {list(self.subscription_configs.keys())}")
async def auto_start_all_subscriptions(self) -> None:
"""Auto-start all subscriptions marked for auto-start."""
if not self.auto_start_enabled:
logger.info("[SUBSCRIPTION_MANAGER] Auto-start disabled")
return
logger.info("[SUBSCRIPTION_MANAGER] Starting auto-start process...")
auto_start_count = 0
for subscription_name, config in self.subscription_configs.items():
if config.get("auto_start", False):
try:
logger.info(f"[SUBSCRIPTION_MANAGER] Auto-starting subscription: {subscription_name}")
await self.start_subscription(subscription_name, str(config["query"]))
auto_start_count += 1
except Exception as e:
logger.error(f"[SUBSCRIPTION_MANAGER] Failed to auto-start {subscription_name}: {e}")
self.last_error[subscription_name] = str(e)
logger.info(f"[SUBSCRIPTION_MANAGER] Auto-start completed. Started {auto_start_count} subscriptions")
async def start_subscription(self, subscription_name: str, query: str, variables: dict[str, Any] | None = None) -> None:
"""Start a GraphQL subscription and maintain it as a resource."""
logger.info(f"[SUBSCRIPTION:{subscription_name}] Starting subscription...")
if subscription_name in self.active_subscriptions:
logger.warning(f"[SUBSCRIPTION:{subscription_name}] Subscription already active, skipping")
return
# Reset connection tracking
self.reconnect_attempts[subscription_name] = 0
self.connection_states[subscription_name] = "starting"
async with self.subscription_lock:
try:
task = asyncio.create_task(self._subscription_loop(subscription_name, query, variables or {}))
self.active_subscriptions[subscription_name] = task
logger.info(f"[SUBSCRIPTION:{subscription_name}] Subscription task created and started")
self.connection_states[subscription_name] = "active"
except Exception as e:
logger.error(f"[SUBSCRIPTION:{subscription_name}] Failed to start subscription task: {e}")
self.connection_states[subscription_name] = "failed"
self.last_error[subscription_name] = str(e)
raise
async def stop_subscription(self, subscription_name: str) -> None:
"""Stop a specific subscription."""
logger.info(f"[SUBSCRIPTION:{subscription_name}] Stopping subscription...")
async with self.subscription_lock:
if subscription_name in self.active_subscriptions:
task = self.active_subscriptions[subscription_name]
task.cancel()
try:
await task
except asyncio.CancelledError:
logger.debug(f"[SUBSCRIPTION:{subscription_name}] Task cancelled successfully")
del self.active_subscriptions[subscription_name]
self.connection_states[subscription_name] = "stopped"
logger.info(f"[SUBSCRIPTION:{subscription_name}] Subscription stopped")
else:
logger.warning(f"[SUBSCRIPTION:{subscription_name}] No active subscription to stop")
async def _subscription_loop(self, subscription_name: str, query: str, variables: dict[str, Any] | None) -> None:
"""Main loop for maintaining a GraphQL subscription with comprehensive logging."""
retry_delay: int | float = 5
max_retry_delay = 300 # 5 minutes max
while True:
attempt = self.reconnect_attempts.get(subscription_name, 0) + 1
self.reconnect_attempts[subscription_name] = attempt
logger.info(f"[WEBSOCKET:{subscription_name}] Connection attempt #{attempt} (max: {self.max_reconnect_attempts})")
if attempt > self.max_reconnect_attempts:
logger.error(f"[WEBSOCKET:{subscription_name}] Max reconnection attempts ({self.max_reconnect_attempts}) exceeded, stopping")
self.connection_states[subscription_name] = "max_retries_exceeded"
break
try:
# Build WebSocket URL with detailed logging
if not UNRAID_API_URL:
raise ValueError("UNRAID_API_URL is not configured")
if UNRAID_API_URL.startswith("https://"):
ws_url = "wss://" + UNRAID_API_URL[len("https://"):]
elif UNRAID_API_URL.startswith("http://"):
ws_url = "ws://" + UNRAID_API_URL[len("http://"):]
else:
ws_url = UNRAID_API_URL
if not ws_url.endswith("/graphql"):
ws_url = ws_url.rstrip("/") + "/graphql"
logger.debug(f"[WEBSOCKET:{subscription_name}] Connecting to: {ws_url}")
logger.debug(f"[WEBSOCKET:{subscription_name}] API Key present: {'Yes' if UNRAID_API_KEY else 'No'}")
ssl_context = build_ws_ssl_context(ws_url)
# Connection with timeout
connect_timeout = 10
logger.debug(f"[WEBSOCKET:{subscription_name}] Connection timeout: {connect_timeout}s")
async with websockets.connect(
ws_url,
subprotocols=[Subprotocol("graphql-transport-ws"), Subprotocol("graphql-ws")],
ping_interval=20,
ping_timeout=10,
close_timeout=10,
ssl=ssl_context
) as websocket:
selected_proto = websocket.subprotocol or "none"
logger.info(f"[WEBSOCKET:{subscription_name}] Connected! Protocol: {selected_proto}")
self.connection_states[subscription_name] = "connected"
# Reset retry count on successful connection
self.reconnect_attempts[subscription_name] = 0
retry_delay = 5 # Reset delay
# Initialize GraphQL-WS protocol
logger.debug(f"[PROTOCOL:{subscription_name}] Initializing GraphQL-WS protocol...")
init_type = "connection_init"
init_payload: dict[str, Any] = {"type": init_type}
if UNRAID_API_KEY:
logger.debug(f"[AUTH:{subscription_name}] Adding authentication payload")
# Use standard X-API-Key header format (matching HTTP client)
auth_payload = {
"headers": {
"X-API-Key": UNRAID_API_KEY
}
}
init_payload["payload"] = auth_payload
else:
logger.warning(f"[AUTH:{subscription_name}] No API key available for authentication")
logger.debug(f"[PROTOCOL:{subscription_name}] Sending connection_init message")
await websocket.send(json.dumps(init_payload))
# Wait for connection acknowledgment
logger.debug(f"[PROTOCOL:{subscription_name}] Waiting for connection_ack...")
init_raw = await asyncio.wait_for(websocket.recv(), timeout=30)
try:
init_data = json.loads(init_raw)
logger.debug(f"[PROTOCOL:{subscription_name}] Received init response: {init_data.get('type')}")
except json.JSONDecodeError as e:
init_preview = init_raw[:200] if isinstance(init_raw, str) else init_raw[:200].decode("utf-8", errors="replace")
logger.error(f"[PROTOCOL:{subscription_name}] Failed to decode init response: {init_preview}...")
self.last_error[subscription_name] = f"Invalid JSON in init response: {e}"
break
# Handle connection acknowledgment
if init_data.get("type") == "connection_ack":
logger.info(f"[PROTOCOL:{subscription_name}] Connection acknowledged successfully")
self.connection_states[subscription_name] = "authenticated"
elif init_data.get("type") == "connection_error":
error_payload = init_data.get("payload", {})
logger.error(f"[AUTH:{subscription_name}] Authentication failed: {error_payload}")
self.last_error[subscription_name] = f"Authentication error: {error_payload}"
self.connection_states[subscription_name] = "auth_failed"
break
else:
logger.warning(f"[PROTOCOL:{subscription_name}] Unexpected init response: {init_data}")
# Continue anyway - some servers send other messages first
# Start the subscription
logger.debug(f"[SUBSCRIPTION:{subscription_name}] Starting GraphQL subscription...")
start_type = "subscribe" if selected_proto == "graphql-transport-ws" else "start"
subscription_message = {
"id": subscription_name,
"type": start_type,
"payload": {
"query": query,
"variables": variables
}
}
logger.debug(f"[SUBSCRIPTION:{subscription_name}] Subscription message type: {start_type}")
logger.debug(f"[SUBSCRIPTION:{subscription_name}] Query: {query[:100]}...")
logger.debug(f"[SUBSCRIPTION:{subscription_name}] Variables: {variables}")
await websocket.send(json.dumps(subscription_message))
logger.info(f"[SUBSCRIPTION:{subscription_name}] Subscription started successfully")
self.connection_states[subscription_name] = "subscribed"
# Listen for subscription data
message_count = 0
async for message in websocket:
try:
data = json.loads(message)
message_count += 1
message_type = data.get("type", "unknown")
logger.debug(f"[DATA:{subscription_name}] Message #{message_count}: {message_type}")
# Handle different message types
expected_data_type = "next" if selected_proto == "graphql-transport-ws" else "data"
if data.get("type") == expected_data_type and data.get("id") == subscription_name:
payload = data.get("payload", {})
if payload.get("data"):
logger.info(f"[DATA:{subscription_name}] Received subscription data update")
self.resource_data[subscription_name] = SubscriptionData(
data=payload["data"],
last_updated=datetime.now(),
subscription_type=subscription_name
)
logger.debug(f"[RESOURCE:{subscription_name}] Resource data updated successfully")
elif payload.get("errors"):
logger.error(f"[DATA:{subscription_name}] GraphQL errors in response: {payload['errors']}")
self.last_error[subscription_name] = f"GraphQL errors: {payload['errors']}"
else:
logger.warning(f"[DATA:{subscription_name}] Empty or invalid data payload: {payload}")
elif data.get("type") == "ping":
logger.debug(f"[PROTOCOL:{subscription_name}] Received ping, sending pong")
await websocket.send(json.dumps({"type": "pong"}))
elif data.get("type") == "error":
error_payload = data.get("payload", {})
logger.error(f"[SUBSCRIPTION:{subscription_name}] Subscription error: {error_payload}")
self.last_error[subscription_name] = f"Subscription error: {error_payload}"
self.connection_states[subscription_name] = "error"
elif data.get("type") == "complete":
logger.info(f"[SUBSCRIPTION:{subscription_name}] Subscription completed by server")
self.connection_states[subscription_name] = "completed"
break
elif data.get("type") in ["ka", "ping", "pong"]:
logger.debug(f"[PROTOCOL:{subscription_name}] Keepalive message: {message_type}")
else:
logger.debug(f"[PROTOCOL:{subscription_name}] Unhandled message type: {message_type}")
except json.JSONDecodeError as e:
msg_preview = message[:200] if isinstance(message, str) else message[:200].decode("utf-8", errors="replace")
logger.error(f"[PROTOCOL:{subscription_name}] Failed to decode message: {msg_preview}...")
logger.error(f"[PROTOCOL:{subscription_name}] JSON decode error: {e}")
except Exception as e:
logger.error(f"[DATA:{subscription_name}] Error processing message: {e}")
msg_preview = message[:200] if isinstance(message, str) else message[:200].decode("utf-8", errors="replace")
logger.debug(f"[DATA:{subscription_name}] Raw message: {msg_preview}...")
except TimeoutError:
error_msg = "Connection or authentication timeout"
logger.error(f"[WEBSOCKET:{subscription_name}] {error_msg}")
self.last_error[subscription_name] = error_msg
self.connection_states[subscription_name] = "timeout"
except websockets.exceptions.ConnectionClosed as e:
error_msg = f"WebSocket connection closed: {e}"
logger.warning(f"[WEBSOCKET:{subscription_name}] {error_msg}")
self.last_error[subscription_name] = error_msg
self.connection_states[subscription_name] = "disconnected"
except websockets.exceptions.InvalidURI as e:
error_msg = f"Invalid WebSocket URI: {e}"
logger.error(f"[WEBSOCKET:{subscription_name}] {error_msg}")
self.last_error[subscription_name] = error_msg
self.connection_states[subscription_name] = "invalid_uri"
break # Don't retry on invalid URI
except Exception as e:
error_msg = f"Unexpected error: {e}"
logger.error(f"[WEBSOCKET:{subscription_name}] {error_msg}")
self.last_error[subscription_name] = error_msg
self.connection_states[subscription_name] = "error"
# Calculate backoff delay
retry_delay = min(retry_delay * 1.5, max_retry_delay)
logger.info(f"[WEBSOCKET:{subscription_name}] Reconnecting in {retry_delay:.1f} seconds...")
self.connection_states[subscription_name] = "reconnecting"
await asyncio.sleep(retry_delay)
def get_resource_data(self, resource_name: str) -> dict[str, Any] | None:
"""Get current resource data with enhanced logging."""
logger.debug(f"[RESOURCE:{resource_name}] Resource data requested")
if resource_name in self.resource_data:
data = self.resource_data[resource_name]
age_seconds = (datetime.now() - data.last_updated).total_seconds()
logger.debug(f"[RESOURCE:{resource_name}] Data found, age: {age_seconds:.1f}s")
return data.data
logger.debug(f"[RESOURCE:{resource_name}] No data available")
return None
def list_active_subscriptions(self) -> list[str]:
"""List all active subscriptions."""
active = list(self.active_subscriptions.keys())
logger.debug(f"[SUBSCRIPTION_MANAGER] Active subscriptions: {active}")
return active
def get_subscription_status(self) -> dict[str, dict[str, Any]]:
"""Get detailed status of all subscriptions for diagnostics."""
status = {}
for sub_name, config in self.subscription_configs.items():
sub_status = {
"config": {
"resource": config["resource"],
"description": config["description"],
"auto_start": config.get("auto_start", False)
},
"runtime": {
"active": sub_name in self.active_subscriptions,
"connection_state": self.connection_states.get(sub_name, "not_started"),
"reconnect_attempts": self.reconnect_attempts.get(sub_name, 0),
"last_error": self.last_error.get(sub_name, None)
}
}
# Add data info if available
if sub_name in self.resource_data:
data_info = self.resource_data[sub_name]
age_seconds = (datetime.now() - data_info.last_updated).total_seconds()
sub_status["data"] = {
"available": True,
"last_updated": data_info.last_updated.isoformat(),
"age_seconds": age_seconds
}
else:
sub_status["data"] = {"available": False}
status[sub_name] = sub_status
logger.debug(f"[SUBSCRIPTION_MANAGER] Generated status for {len(status)} subscriptions")
return status
# Module-level SubscriptionManager shared by this package. It is created at
# import time, so the environment settings read by SubscriptionManager.__init__
# are sampled once, when this module is first imported.
subscription_manager: SubscriptionManager = SubscriptionManager()