diff --git a/unraid_mcp/subscriptions/diagnostics.py b/unraid_mcp/subscriptions/diagnostics.py index 5862455..4207aa4 100644 --- a/unraid_mcp/subscriptions/diagnostics.py +++ b/unraid_mcp/subscriptions/diagnostics.py @@ -6,7 +6,6 @@ development and debugging purposes. """ import asyncio -import contextlib import json import re from datetime import UTC, datetime @@ -25,35 +24,38 @@ from .resources import ensure_subscriptions_started from .utils import _analyze_subscription_status, build_ws_ssl_context, build_ws_url -_ALLOWED_SUBSCRIPTION_NAMES = frozenset( +# Schema field names that appear inside the selection set of allowed subscriptions. +# The regex _SUBSCRIPTION_NAME_PATTERN extracts the first identifier after the +# opening "{", so we list the actual field names used in queries (e.g. "logFile"), +# NOT the operation-level names (e.g. "logFileSubscription"). +_ALLOWED_SUBSCRIPTION_FIELDS = frozenset( { - "logFileSubscription", - "containerStatsSubscription", - "cpuSubscription", - "memorySubscription", - "arraySubscription", - "networkSubscription", - "dockerSubscription", - "vmSubscription", + "logFile", + "containerStats", + "cpu", + "memory", + "array", + "network", + "docker", + "vm", } ) -# Pattern: must start with "subscription" and contain only a known subscription name. -# _FORBIDDEN_KEYWORDS rejects any query that contains standalone "mutation" or "query" -# as distinct words. Word boundaries (\b) ensure "mutationField"-style identifiers are -# not rejected — only bare "mutation" or "query" operation keywords are blocked. +# Pattern: must start with "subscription" keyword, then extract the first selected +# field name (the word immediately after "{"). _SUBSCRIPTION_NAME_PATTERN = re.compile(r"^\s*subscription\b[^{]*\{\s*(\w+)", re.IGNORECASE) +# Reject any query that contains a bare "mutation" or "query" operation keyword. _FORBIDDEN_KEYWORDS = re.compile(r"\b(mutation|query)\b", re.IGNORECASE) def _validate_subscription_query(query: str) -> str: """Validate that a subscription query is safe to execute. - Only allows subscription operations targeting whitelisted subscription names. + Only allows subscription operations targeting whitelisted schema field names. Rejects any query containing mutation/query keywords. Returns: - The extracted subscription name. + The extracted field name (e.g. "logFile"). Raises: ToolError: If the query fails validation. @@ -65,17 +67,17 @@ def _validate_subscription_query(query: str) -> str: if not match: raise ToolError( "Query rejected: must start with 'subscription' and contain a valid " - "subscription operation. Example: subscription { logFileSubscription { ... } }" + 'subscription field. Example: subscription { logFile(path: "/var/log/syslog") { content } }' ) - sub_name = match.group(1) - if sub_name not in _ALLOWED_SUBSCRIPTION_NAMES: + field_name = match.group(1) + if field_name not in _ALLOWED_SUBSCRIPTION_FIELDS: raise ToolError( - f"Subscription '{sub_name}' is not allowed. " - f"Allowed subscriptions: {sorted(_ALLOWED_SUBSCRIPTION_NAMES)}" + f"Subscription field '{field_name}' is not allowed. " + f"Allowed fields: {sorted(_ALLOWED_SUBSCRIPTION_FIELDS)}" ) - return sub_name + return field_name def register_diagnostic_tools(mcp: FastMCP) -> None: @@ -90,10 +92,8 @@ def register_diagnostic_tools(mcp: FastMCP) -> None: """Test a GraphQL subscription query directly to debug schema issues. Use this to find working subscription field names and structure. - Only whitelisted subscriptions are allowed (logFileSubscription, - containerStatsSubscription, cpuSubscription, memorySubscription, - arraySubscription, networkSubscription, dockerSubscription, - vmSubscription). + Only whitelisted schema fields are permitted (logFile, containerStats, + cpu, memory, array, network, docker, vm). Args: subscription_query: The GraphQL subscription query to test @@ -101,12 +101,13 @@ def register_diagnostic_tools(mcp: FastMCP) -> None: Returns: Dict containing test results and response data """ - # Validate before any network I/O - sub_name = _validate_subscription_query(subscription_query) + # Validate before any network I/O (Bug 1 fix) + field_name = _validate_subscription_query(subscription_query) try: - logger.info(f"[TEST_SUBSCRIPTION] Testing validated subscription '{sub_name}'") + logger.info(f"[TEST_SUBSCRIPTION] Testing validated subscription field '{field_name}'") + # Build WebSocket URL — raises ValueError on invalid/missing scheme (Bug 4 fix) try: ws_url = build_ws_url() except ValueError as e: @@ -187,11 +188,18 @@ def register_diagnostic_tools(mcp: FastMCP) -> None: # Get comprehensive status status = await subscription_manager.get_subscription_status() - # Analyze connection issues and error counts via the shared helper. - # This ensures "invalid_uri" and all other error states are counted - # consistently with the health tool's _diagnose_subscriptions path. + # Analyze connection issues and error counts via shared helper. + # Gates connection_issues on current failure state (Bug 5 fix). error_count, connection_issues = _analyze_subscription_status(status) + # Calculate WebSocket URL + ws_url_display: str | None = None + if UNRAID_API_URL: + try: + ws_url_display = build_ws_url() + except ValueError: + ws_url_display = None + # Add environment info with explicit typing diagnostic_info: dict[str, Any] = { "timestamp": datetime.now(UTC).isoformat(), @@ -200,7 +208,7 @@ def register_diagnostic_tools(mcp: FastMCP) -> None: "max_reconnect_attempts": subscription_manager.max_reconnect_attempts, "unraid_api_url": safe_display_url(UNRAID_API_URL), "api_key_configured": bool(UNRAID_API_KEY), - "websocket_url": None, + "websocket_url": ws_url_display, }, "subscriptions": status, "summary": { @@ -217,10 +225,6 @@ def register_diagnostic_tools(mcp: FastMCP) -> None: }, } - # Calculate WebSocket URL (stays None if UNRAID_API_URL not configured) - with contextlib.suppress(ValueError): - diagnostic_info["environment"]["websocket_url"] = build_ws_url() - # Add troubleshooting recommendations recommendations: list[str] = [] diff --git a/unraid_mcp/subscriptions/manager.py b/unraid_mcp/subscriptions/manager.py index e9453ca..fef4309 100644 --- a/unraid_mcp/subscriptions/manager.py +++ b/unraid_mcp/subscriptions/manager.py @@ -74,6 +74,57 @@ def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]: return result +# Resource data size limits to prevent unbounded memory growth +_MAX_RESOURCE_DATA_BYTES = 1_048_576 # 1 MB +_MAX_RESOURCE_DATA_LINES = 5_000 + + +def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]: + """Cap log content in subscription data to prevent unbounded memory growth. + + Returns a new dict — does NOT mutate the input. If any nested 'content' + field (from log subscriptions) exceeds the byte limit, truncate it to the + most recent _MAX_RESOURCE_DATA_LINES lines. + + The final content is guaranteed to be <= _MAX_RESOURCE_DATA_BYTES. + """ + result: dict[str, Any] = {} + for key, value in data.items(): + if isinstance(value, dict): + result[key] = _cap_log_content(value) + elif ( + key == "content" + and isinstance(value, str) + # Pre-check uses byte count so multibyte UTF-8 chars cannot bypass the cap + and len(value.encode("utf-8", errors="replace")) > _MAX_RESOURCE_DATA_BYTES + ): + lines = value.splitlines() + original_line_count = len(lines) + + # Keep most recent lines first. + if len(lines) > _MAX_RESOURCE_DATA_LINES: + lines = lines[-_MAX_RESOURCE_DATA_LINES:] + + truncated = "\n".join(lines) + # Encode once and slice bytes instead of an O(n²) line-trim loop + encoded = truncated.encode("utf-8", errors="replace") + if len(encoded) > _MAX_RESOURCE_DATA_BYTES: + truncated = encoded[-_MAX_RESOURCE_DATA_BYTES:].decode("utf-8", errors="ignore") + # Strip partial first line that may have been cut mid-character + nl_pos = truncated.find("\n") + if nl_pos != -1: + truncated = truncated[nl_pos + 1 :] + + logger.warning( + f"[RESOURCE] Capped log content from {original_line_count} to " + f"{len(lines)} lines ({len(value)} -> {len(truncated)} chars)" + ) + result[key] = truncated + else: + result[key] = value + return result + + class SubscriptionManager: """Manages GraphQL subscriptions and converts them to MCP resources.""" diff --git a/unraid_mcp/subscriptions/resources.py b/unraid_mcp/subscriptions/resources.py index cd780a5..715cf78 100644 --- a/unraid_mcp/subscriptions/resources.py +++ b/unraid_mcp/subscriptions/resources.py @@ -48,7 +48,7 @@ async def autostart_subscriptions() -> None: logger.info("[AUTOSTART] Initiating subscription auto-start process...") try: - # Use the new SubscriptionManager auto-start method + # Use the SubscriptionManager auto-start method await subscription_manager.auto_start_all_subscriptions() logger.info("[AUTOSTART] Auto-start process completed successfully") except Exception as e: diff --git a/unraid_mcp/subscriptions/utils.py b/unraid_mcp/subscriptions/utils.py index 83c1c4d..c704e21 100644 --- a/unraid_mcp/subscriptions/utils.py +++ b/unraid_mcp/subscriptions/utils.py @@ -15,7 +15,7 @@ def build_ws_url() -> str: The WebSocket URL string (e.g. "wss://10.1.0.2:31337/graphql"). Raises: - ValueError: If UNRAID_API_URL is not configured. + ValueError: If UNRAID_API_URL is not configured or has an unrecognised scheme. """ if not UNRAID_API_URL: raise ValueError("UNRAID_API_URL is not configured") @@ -24,8 +24,13 @@ def build_ws_url() -> str: ws_url = "wss://" + UNRAID_API_URL[len("https://") :] elif UNRAID_API_URL.startswith("http://"): ws_url = "ws://" + UNRAID_API_URL[len("http://") :] + elif UNRAID_API_URL.startswith(("ws://", "wss://")): + ws_url = UNRAID_API_URL # Already a WebSocket URL else: - ws_url = UNRAID_API_URL + raise ValueError( + f"UNRAID_API_URL must start with http://, https://, ws://, or wss://. " + f"Got: {UNRAID_API_URL[:20]}..." + ) if not ws_url.endswith("/graphql"): ws_url = ws_url.rstrip("/") + "/graphql" @@ -60,8 +65,8 @@ def _analyze_subscription_status( ) -> tuple[int, list[dict[str, Any]]]: """Analyze subscription status dict, returning error count and connection issues. - This is the canonical, shared implementation used by both the health tool - and the subscription diagnostics tool. + Only reports connection_issues for subscriptions that are currently in a + failure state (not recovered ones that happen to have a stale last_error). Args: status: Dict of subscription name -> status info from get_subscription_status(). @@ -69,15 +74,19 @@ def _analyze_subscription_status( Returns: Tuple of (error_count, connection_issues_list). """ + _error_states = frozenset( + {"error", "auth_failed", "timeout", "max_retries_exceeded", "invalid_uri"} + ) error_count = 0 connection_issues: list[dict[str, Any]] = [] for sub_name, sub_status in status.items(): runtime = sub_status.get("runtime", {}) conn_state = runtime.get("connection_state", "unknown") - if conn_state in ("error", "auth_failed", "timeout", "max_retries_exceeded", "invalid_uri"): + if conn_state in _error_states: error_count += 1 - if runtime.get("last_error"): + # Gate on current failure state so recovered subscriptions are not reported + if runtime.get("last_error") and conn_state in _error_states: connection_issues.append( { "subscription": sub_name,