diff --git a/unraid_mcp/subscriptions/diagnostics.py b/unraid_mcp/subscriptions/diagnostics.py index ea77e69..ed9925b 100644 --- a/unraid_mcp/subscriptions/diagnostics.py +++ b/unraid_mcp/subscriptions/diagnostics.py @@ -7,6 +7,7 @@ development and debugging purposes. import asyncio import json +import re from datetime import datetime from typing import Any @@ -19,7 +20,63 @@ from ..config.settings import UNRAID_API_KEY, UNRAID_API_URL from ..core.exceptions import ToolError from .manager import subscription_manager from .resources import ensure_subscriptions_started -from .utils import build_ws_ssl_context +from .utils import _analyze_subscription_status, build_ws_ssl_context, build_ws_url + + +# Schema field names that appear inside the selection set of allowed subscriptions. +# The regex _SUBSCRIPTION_NAME_PATTERN extracts the first identifier after the +# opening "{", so we list the actual field names used in queries (e.g. "logFile"), +# NOT the operation-level names (e.g. "logFileSubscription"). +_ALLOWED_SUBSCRIPTION_FIELDS = frozenset( + { + "logFile", + "containerStats", + "cpu", + "memory", + "array", + "network", + "docker", + "vm", + } +) + +# Pattern: must start with "subscription" keyword, then extract the first selected +# field name (the word immediately after "{"). +_SUBSCRIPTION_NAME_PATTERN = re.compile(r"^\s*subscription\b[^{]*\{\s*(\w+)", re.IGNORECASE) +# Reject any query that contains a bare "mutation" or "query" operation keyword. +_FORBIDDEN_KEYWORDS = re.compile(r"\b(mutation|query)\b", re.IGNORECASE) + + +def _validate_subscription_query(query: str) -> str: + """Validate that a subscription query is safe to execute. + + Only allows subscription operations targeting whitelisted schema field names. + Rejects any query containing mutation/query keywords. + + Returns: + The extracted field name (e.g. "logFile"). + + Raises: + ToolError: If the query fails validation. 
+ """ + if _FORBIDDEN_KEYWORDS.search(query): + raise ToolError("Query rejected: must be a subscription, not a mutation or query.") + + match = _SUBSCRIPTION_NAME_PATTERN.match(query) + if not match: + raise ToolError( + "Query rejected: must start with 'subscription' and contain a valid " + 'subscription field. Example: subscription { logFile(path: "/var/log/syslog") { content } }' + ) + + field_name = match.group(1) + if field_name not in _ALLOWED_SUBSCRIPTION_FIELDS: + raise ToolError( + f"Subscription field '{field_name}' is not allowed. " + f"Allowed fields: {sorted(_ALLOWED_SUBSCRIPTION_FIELDS)}" + ) + + return field_name def register_diagnostic_tools(mcp: FastMCP) -> None: @@ -34,6 +91,8 @@ def register_diagnostic_tools(mcp: FastMCP) -> None: """Test a GraphQL subscription query directly to debug schema issues. Use this to find working subscription field names and structure. + Only whitelisted schema fields are permitted (logFile, containerStats, + cpu, memory, array, network, docker, vm). 
Args: subscription_query: The GraphQL subscription query to test @@ -41,16 +100,17 @@ def register_diagnostic_tools(mcp: FastMCP) -> None: Returns: Dict containing test results and response data """ - try: - logger.info(f"[TEST_SUBSCRIPTION] Testing query: {subscription_query}") + # Validate before any network I/O (Bug 1 fix) + field_name = _validate_subscription_query(subscription_query) - # Build WebSocket URL - if not UNRAID_API_URL: - raise ToolError("UNRAID_API_URL is not configured") - ws_url = ( - UNRAID_API_URL.replace("https://", "wss://").replace("http://", "ws://") - + "/graphql" - ) + try: + logger.info(f"[TEST_SUBSCRIPTION] Testing validated subscription field '{field_name}'") + + # Build WebSocket URL — raises ValueError on invalid/missing scheme (Bug 4 fix) + try: + ws_url = build_ws_url() + except ValueError as e: + raise ToolError(str(e)) from e ssl_context = build_ws_ssl_context(ws_url) @@ -59,6 +119,7 @@ def register_diagnostic_tools(mcp: FastMCP) -> None: ws_url, subprotocols=[Subprotocol("graphql-transport-ws"), Subprotocol("graphql-ws")], ssl=ssl_context, + open_timeout=10, ping_interval=30, ping_timeout=10, ) as websocket: @@ -102,6 +163,8 @@ def register_diagnostic_tools(mcp: FastMCP) -> None: "note": "Connection successful, subscription may be waiting for events", } + except ToolError: + raise except Exception as e: logger.error(f"[TEST_SUBSCRIPTION] Error: {e}", exc_info=True) return {"error": str(e), "query_tested": subscription_query} @@ -124,8 +187,17 @@ def register_diagnostic_tools(mcp: FastMCP) -> None: # Get comprehensive status status = subscription_manager.get_subscription_status() - # Initialize connection issues list with proper type - connection_issues: list[dict[str, Any]] = [] + # Analyze connection issues and error counts via shared helper. + # Gates connection_issues on current failure state (Bug 5 fix). 
+ error_count, connection_issues = _analyze_subscription_status(status) + + # Calculate WebSocket URL + ws_url_display: str | None = None + if UNRAID_API_URL: + try: + ws_url_display = build_ws_url() + except ValueError: + ws_url_display = None # Add environment info with explicit typing diagnostic_info: dict[str, Any] = { @@ -135,7 +207,7 @@ def register_diagnostic_tools(mcp: FastMCP) -> None: "max_reconnect_attempts": subscription_manager.max_reconnect_attempts, "unraid_api_url": UNRAID_API_URL[:50] + "..." if UNRAID_API_URL else None, "api_key_configured": bool(UNRAID_API_KEY), - "websocket_url": None, + "websocket_url": ws_url_display, }, "subscriptions": status, "summary": { @@ -147,40 +219,11 @@ def register_diagnostic_tools(mcp: FastMCP) -> None: ), "active_count": len(subscription_manager.active_subscriptions), "with_data": len(subscription_manager.resource_data), - "in_error_state": 0, + "in_error_state": error_count, "connection_issues": connection_issues, }, } - # Calculate WebSocket URL - if UNRAID_API_URL: - if UNRAID_API_URL.startswith("https://"): - ws_url = "wss://" + UNRAID_API_URL[len("https://") :] - elif UNRAID_API_URL.startswith("http://"): - ws_url = "ws://" + UNRAID_API_URL[len("http://") :] - else: - ws_url = UNRAID_API_URL - if not ws_url.endswith("/graphql"): - ws_url = ws_url.rstrip("/") + "/graphql" - diagnostic_info["environment"]["websocket_url"] = ws_url - - # Analyze issues - for sub_name, sub_status in status.items(): - runtime = sub_status.get("runtime", {}) - connection_state = runtime.get("connection_state", "unknown") - - if connection_state in ["error", "auth_failed", "timeout", "max_retries_exceeded"]: - diagnostic_info["summary"]["in_error_state"] += 1 - - if runtime.get("last_error"): - connection_issues.append( - { - "subscription": sub_name, - "state": connection_state, - "error": runtime["last_error"], - } - ) - # Add troubleshooting recommendations recommendations: list[str] = [] diff --git 
a/unraid_mcp/subscriptions/manager.py b/unraid_mcp/subscriptions/manager.py index c98be94..69cf6ee 100644 --- a/unraid_mcp/subscriptions/manager.py +++ b/unraid_mcp/subscriptions/manager.py @@ -20,6 +20,57 @@ from ..core.types import SubscriptionData from .utils import build_ws_ssl_context +# Resource data size limits to prevent unbounded memory growth +_MAX_RESOURCE_DATA_BYTES = 1_048_576 # 1 MB +_MAX_RESOURCE_DATA_LINES = 5_000 + + +def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]: + """Cap log content in subscription data to prevent unbounded memory growth. + + Returns a new dict — does NOT mutate the input. If any nested 'content' + field (from log subscriptions) exceeds the byte limit, truncate it to the + most recent _MAX_RESOURCE_DATA_LINES lines. + + The final content is guaranteed to be <= _MAX_RESOURCE_DATA_BYTES. + """ + result: dict[str, Any] = {} + for key, value in data.items(): + if isinstance(value, dict): + result[key] = _cap_log_content(value) + elif ( + key == "content" + and isinstance(value, str) + # Pre-check uses byte count so multibyte UTF-8 chars cannot bypass the cap + and len(value.encode("utf-8", errors="replace")) > _MAX_RESOURCE_DATA_BYTES + ): + lines = value.splitlines() + original_line_count = len(lines) + + # Keep most recent lines first. 
+ if len(lines) > _MAX_RESOURCE_DATA_LINES: + lines = lines[-_MAX_RESOURCE_DATA_LINES:] + + truncated = "\n".join(lines) + # Encode once and slice bytes instead of an O(n²) line-trim loop + encoded = truncated.encode("utf-8", errors="replace") + if len(encoded) > _MAX_RESOURCE_DATA_BYTES: + truncated = encoded[-_MAX_RESOURCE_DATA_BYTES:].decode("utf-8", errors="ignore") + # Strip partial first line that may have been cut mid-character + nl_pos = truncated.find("\n") + if nl_pos != -1: + truncated = truncated[nl_pos + 1 :] + + logger.warning( + f"[RESOURCE] Capped log content from {original_line_count} to " + f"{len(lines)} lines ({len(value)} -> {len(truncated)} chars)" + ) + result[key] = truncated + else: + result[key] = value + return result + + class SubscriptionManager: """Manages GraphQL subscriptions and converts them to MCP resources.""" diff --git a/unraid_mcp/subscriptions/resources.py b/unraid_mcp/subscriptions/resources.py index f1b4caf..58f7b6f 100644 --- a/unraid_mcp/subscriptions/resources.py +++ b/unraid_mcp/subscriptions/resources.py @@ -4,8 +4,10 @@ This module defines MCP resources that bridge between the subscription manager and the MCP protocol, providing fallback queries when subscription data is unavailable. 
# Global flag to track subscription startup; set only after a successful start.
_subscriptions_started = False
# Serializes first-time startup across concurrent callers.
_startup_lock: Final[asyncio.Lock] = asyncio.Lock()


async def ensure_subscriptions_started() -> None:
    """Start subscriptions once, on the first call made from an async context.

    Calls after a successful start return immediately. A failed start is logged
    and swallowed, and the started flag stays False, so the next call tries
    again.
    """
    global _subscriptions_started

    # Cheap check first so the common already-started case never takes the lock.
    if not _subscriptions_started:
        async with _startup_lock:
            # Check again: another task may have finished startup while this
            # one was waiting for the lock.
            if not _subscriptions_started:
                logger.info("[STARTUP] First async operation detected, starting subscriptions...")
                try:
                    await autostart_subscriptions()
                    _subscriptions_started = True
                    logger.info("[STARTUP] Subscriptions started successfully")
                except Exception as e:
                    logger.error(f"[STARTUP] Failed to start subscriptions: {e}", exc_info=True)
--git a/unraid_mcp/subscriptions/utils.py b/unraid_mcp/subscriptions/utils.py index 63674a3..7ca8fc6 100644 --- a/unraid_mcp/subscriptions/utils.py +++ b/unraid_mcp/subscriptions/utils.py @@ -1,8 +1,79 @@ """Shared utilities for the subscription system.""" import ssl as _ssl +from typing import Any -from ..config.settings import UNRAID_VERIFY_SSL +from ..config.settings import UNRAID_API_URL, UNRAID_VERIFY_SSL + + +def build_ws_url() -> str: + """Build a WebSocket URL from the configured UNRAID_API_URL. + + Converts http(s) scheme to ws(s) and ensures /graphql path suffix. + + Returns: + The WebSocket URL string (e.g. "wss://10.1.0.2:31337/graphql"). + + Raises: + ValueError: If UNRAID_API_URL is not configured or has an unrecognised scheme. + """ + if not UNRAID_API_URL: + raise ValueError("UNRAID_API_URL is not configured") + + if UNRAID_API_URL.startswith("https://"): + ws_url = "wss://" + UNRAID_API_URL[len("https://") :] + elif UNRAID_API_URL.startswith("http://"): + ws_url = "ws://" + UNRAID_API_URL[len("http://") :] + elif UNRAID_API_URL.startswith(("ws://", "wss://")): + ws_url = UNRAID_API_URL # Already a WebSocket URL + else: + raise ValueError( + f"UNRAID_API_URL must start with http://, https://, ws://, or wss://. " + f"Got: {UNRAID_API_URL[:20]}..." + ) + + if not ws_url.endswith("/graphql"): + ws_url = ws_url.rstrip("/") + "/graphql" + + return ws_url + + +def _analyze_subscription_status( + status: dict[str, Any], +) -> tuple[int, list[dict[str, Any]]]: + """Analyze subscription status dict, returning error count and connection issues. + + Only reports connection_issues for subscriptions that are currently in a + failure state (not recovered ones that happen to have a stale last_error). + + Args: + status: Dict of subscription name -> status info from get_subscription_status(). + + Returns: + Tuple of (error_count, connection_issues_list). 
+ """ + _error_states = frozenset( + {"error", "auth_failed", "timeout", "max_retries_exceeded", "invalid_uri"} + ) + error_count = 0 + connection_issues: list[dict[str, Any]] = [] + + for sub_name, sub_status in status.items(): + runtime = sub_status.get("runtime", {}) + conn_state = runtime.get("connection_state", "unknown") + if conn_state in _error_states: + error_count += 1 + # Gate on current failure state so recovered subscriptions are not reported + if runtime.get("last_error") and conn_state in _error_states: + connection_issues.append( + { + "subscription": sub_name, + "state": conn_state, + "error": runtime["last_error"], + } + ) + + return error_count, connection_issues def build_ws_ssl_context(ws_url: str) -> _ssl.SSLContext | None: