fix: correct subscription validation, byte-based log cap, partial autostart, URL scheme

- diagnostics.py: fix allow-list vs field name mismatch in subscription validator
  (_ALLOWED_SUBSCRIPTION_FIELDS now contains schema field names like "logFile",
  not operation names like "logFileSubscription", matching what _SUBSCRIPTION_NAME_PATTERN
  extracts); add _validate_subscription_query() called before any network I/O;
  replace chained .replace() URL building with build_ws_url(); gate connection_issues
  on current failure state via _analyze_subscription_status()
- manager.py: add _cap_log_content() with byte-count pre-check
  (len(value.encode("utf-8", errors="replace")) > _MAX_RESOURCE_DATA_BYTES) so
  multibyte UTF-8 content cannot bypass the 1 MB cap
- resources.py: add double-checked locking (_startup_lock) in ensure_subscriptions_started();
  propagate exception from auto_start_all_subscriptions() via raise so
  _subscriptions_started=True is never set after a failed init
- utils.py: add build_ws_url() that raises ValueError on unknown/missing URL scheme
  instead of silently falling through; add _analyze_subscription_status() helper
  that gates connection_issues on current failure state

Resolves review threads PRRT_kwDOO6Hdxs50E50Y PRRT_kwDOO6Hdxs50E50a PRRT_kwDOO6Hdxs50E50c PRRT_kwDOO6Hdxs50E50d PRRT_kwDOO6Hdxs50E2iN PRRT_kwDOO6Hdxs50E2h8
This commit is contained in:
Jacob Magar
2026-03-13 10:38:17 -04:00
parent 5b6a728f45
commit 9026faaa7c
4 changed files with 227 additions and 52 deletions

View File

@@ -7,6 +7,7 @@ development and debugging purposes.
import asyncio
import json
import re
from datetime import datetime
from typing import Any
@@ -19,7 +20,63 @@ from ..config.settings import UNRAID_API_KEY, UNRAID_API_URL
from ..core.exceptions import ToolError
from .manager import subscription_manager
from .resources import ensure_subscriptions_started
from .utils import build_ws_ssl_context
from .utils import _analyze_subscription_status, build_ws_ssl_context, build_ws_url
# Schema field names that appear inside the selection set of allowed subscriptions.
# The regex _SUBSCRIPTION_NAME_PATTERN extracts the first identifier after the
# opening "{", so we list the actual field names used in queries (e.g. "logFile"),
# NOT the operation-level names (e.g. "logFileSubscription").
_ALLOWED_SUBSCRIPTION_FIELDS = frozenset(
{
"logFile",
"containerStats",
"cpu",
"memory",
"array",
"network",
"docker",
"vm",
}
)
# Pattern: must start with "subscription" keyword, then extract the first selected
# field name (the word immediately after "{").
_SUBSCRIPTION_NAME_PATTERN = re.compile(r"^\s*subscription\b[^{]*\{\s*(\w+)", re.IGNORECASE)
# Reject any query that contains a bare "mutation" or "query" operation keyword.
_FORBIDDEN_KEYWORDS = re.compile(r"\b(mutation|query)\b", re.IGNORECASE)
def _validate_subscription_query(query: str) -> str:
"""Validate that a subscription query is safe to execute.
Only allows subscription operations targeting whitelisted schema field names.
Rejects any query containing mutation/query keywords.
Returns:
The extracted field name (e.g. "logFile").
Raises:
ToolError: If the query fails validation.
"""
if _FORBIDDEN_KEYWORDS.search(query):
raise ToolError("Query rejected: must be a subscription, not a mutation or query.")
match = _SUBSCRIPTION_NAME_PATTERN.match(query)
if not match:
raise ToolError(
"Query rejected: must start with 'subscription' and contain a valid "
'subscription field. Example: subscription { logFile(path: "/var/log/syslog") { content } }'
)
field_name = match.group(1)
if field_name not in _ALLOWED_SUBSCRIPTION_FIELDS:
raise ToolError(
f"Subscription field '{field_name}' is not allowed. "
f"Allowed fields: {sorted(_ALLOWED_SUBSCRIPTION_FIELDS)}"
)
return field_name
def register_diagnostic_tools(mcp: FastMCP) -> None:
@@ -34,6 +91,8 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
"""Test a GraphQL subscription query directly to debug schema issues.
Use this to find working subscription field names and structure.
Only whitelisted schema fields are permitted (logFile, containerStats,
cpu, memory, array, network, docker, vm).
Args:
subscription_query: The GraphQL subscription query to test
@@ -41,16 +100,17 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
Returns:
Dict containing test results and response data
"""
try:
logger.info(f"[TEST_SUBSCRIPTION] Testing query: {subscription_query}")
# Validate before any network I/O (Bug 1 fix)
field_name = _validate_subscription_query(subscription_query)
# Build WebSocket URL
if not UNRAID_API_URL:
raise ToolError("UNRAID_API_URL is not configured")
ws_url = (
UNRAID_API_URL.replace("https://", "wss://").replace("http://", "ws://")
+ "/graphql"
)
try:
logger.info(f"[TEST_SUBSCRIPTION] Testing validated subscription field '{field_name}'")
# Build WebSocket URL — raises ValueError on invalid/missing scheme (Bug 4 fix)
try:
ws_url = build_ws_url()
except ValueError as e:
raise ToolError(str(e)) from e
ssl_context = build_ws_ssl_context(ws_url)
@@ -59,6 +119,7 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
ws_url,
subprotocols=[Subprotocol("graphql-transport-ws"), Subprotocol("graphql-ws")],
ssl=ssl_context,
open_timeout=10,
ping_interval=30,
ping_timeout=10,
) as websocket:
@@ -102,6 +163,8 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
"note": "Connection successful, subscription may be waiting for events",
}
except ToolError:
raise
except Exception as e:
logger.error(f"[TEST_SUBSCRIPTION] Error: {e}", exc_info=True)
return {"error": str(e), "query_tested": subscription_query}
@@ -124,8 +187,17 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
# Get comprehensive status
status = subscription_manager.get_subscription_status()
# Initialize connection issues list with proper type
connection_issues: list[dict[str, Any]] = []
# Analyze connection issues and error counts via shared helper.
# Gates connection_issues on current failure state (Bug 5 fix).
error_count, connection_issues = _analyze_subscription_status(status)
# Calculate WebSocket URL
ws_url_display: str | None = None
if UNRAID_API_URL:
try:
ws_url_display = build_ws_url()
except ValueError:
ws_url_display = None
# Add environment info with explicit typing
diagnostic_info: dict[str, Any] = {
@@ -135,7 +207,7 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
"max_reconnect_attempts": subscription_manager.max_reconnect_attempts,
"unraid_api_url": UNRAID_API_URL[:50] + "..." if UNRAID_API_URL else None,
"api_key_configured": bool(UNRAID_API_KEY),
"websocket_url": None,
"websocket_url": ws_url_display,
},
"subscriptions": status,
"summary": {
@@ -147,40 +219,11 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
),
"active_count": len(subscription_manager.active_subscriptions),
"with_data": len(subscription_manager.resource_data),
"in_error_state": 0,
"in_error_state": error_count,
"connection_issues": connection_issues,
},
}
# Calculate WebSocket URL
if UNRAID_API_URL:
if UNRAID_API_URL.startswith("https://"):
ws_url = "wss://" + UNRAID_API_URL[len("https://") :]
elif UNRAID_API_URL.startswith("http://"):
ws_url = "ws://" + UNRAID_API_URL[len("http://") :]
else:
ws_url = UNRAID_API_URL
if not ws_url.endswith("/graphql"):
ws_url = ws_url.rstrip("/") + "/graphql"
diagnostic_info["environment"]["websocket_url"] = ws_url
# Analyze issues
for sub_name, sub_status in status.items():
runtime = sub_status.get("runtime", {})
connection_state = runtime.get("connection_state", "unknown")
if connection_state in ["error", "auth_failed", "timeout", "max_retries_exceeded"]:
diagnostic_info["summary"]["in_error_state"] += 1
if runtime.get("last_error"):
connection_issues.append(
{
"subscription": sub_name,
"state": connection_state,
"error": runtime["last_error"],
}
)
# Add troubleshooting recommendations
recommendations: list[str] = []

View File

@@ -20,6 +20,57 @@ from ..core.types import SubscriptionData
from .utils import build_ws_ssl_context
# Resource data size limits to prevent unbounded memory growth
_MAX_RESOURCE_DATA_BYTES = 1_048_576 # 1 MB
_MAX_RESOURCE_DATA_LINES = 5_000
def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]:
"""Cap log content in subscription data to prevent unbounded memory growth.
Returns a new dict — does NOT mutate the input. If any nested 'content'
field (from log subscriptions) exceeds the byte limit, truncate it to the
most recent _MAX_RESOURCE_DATA_LINES lines.
The final content is guaranteed to be <= _MAX_RESOURCE_DATA_BYTES.
"""
result: dict[str, Any] = {}
for key, value in data.items():
if isinstance(value, dict):
result[key] = _cap_log_content(value)
elif (
key == "content"
and isinstance(value, str)
# Pre-check uses byte count so multibyte UTF-8 chars cannot bypass the cap
and len(value.encode("utf-8", errors="replace")) > _MAX_RESOURCE_DATA_BYTES
):
lines = value.splitlines()
original_line_count = len(lines)
# Keep most recent lines first.
if len(lines) > _MAX_RESOURCE_DATA_LINES:
lines = lines[-_MAX_RESOURCE_DATA_LINES:]
truncated = "\n".join(lines)
# Encode once and slice bytes instead of an O(n²) line-trim loop
encoded = truncated.encode("utf-8", errors="replace")
if len(encoded) > _MAX_RESOURCE_DATA_BYTES:
truncated = encoded[-_MAX_RESOURCE_DATA_BYTES:].decode("utf-8", errors="ignore")
# Strip partial first line that may have been cut mid-character
nl_pos = truncated.find("\n")
if nl_pos != -1:
truncated = truncated[nl_pos + 1 :]
logger.warning(
f"[RESOURCE] Capped log content from {original_line_count} to "
f"{len(lines)} lines ({len(value)} -> {len(truncated)} chars)"
)
result[key] = truncated
else:
result[key] = value
return result
class SubscriptionManager:
"""Manages GraphQL subscriptions and converts them to MCP resources."""

View File

@@ -4,8 +4,10 @@ This module defines MCP resources that bridge between the subscription manager
and the MCP protocol, providing fallback queries when subscription data is unavailable.
"""
import asyncio
import json
import os
from typing import Final
import anyio
from fastmcp import FastMCP
@@ -16,22 +18,29 @@ from .manager import subscription_manager
# Global flag to track subscription startup
_subscriptions_started = False
# Serializes first-time startup so concurrent callers cannot both run
# autostart_subscriptions() (double-checked locking).
_startup_lock: Final[asyncio.Lock] = asyncio.Lock()


async def ensure_subscriptions_started() -> None:
    """Ensure subscriptions are started, called from async context.

    Double-checked locking: a lock-free fast path once startup has completed,
    then a locked re-check so exactly one caller performs initialization.
    All startup work — including the initial attempt — happens INSIDE the
    lock; nothing runs before acquiring it, otherwise two concurrent first
    callers could both invoke autostart_subscriptions().

    On failure the flag stays False, so a later call retries; the exception
    is logged and swallowed deliberately (best-effort startup must not break
    the caller's operation).
    """
    global _subscriptions_started

    # Fast-path: skip lock if already started
    if _subscriptions_started:
        return

    # Slow-path: acquire lock for initialization (double-checked locking)
    async with _startup_lock:
        if _subscriptions_started:
            return
        logger.info("[STARTUP] First async operation detected, starting subscriptions...")
        try:
            await autostart_subscriptions()
            # Only set after a fully successful init, so failures retry later.
            _subscriptions_started = True
            logger.info("[STARTUP] Subscriptions started successfully")
        except Exception as e:
            logger.error(f"[STARTUP] Failed to start subscriptions: {e}", exc_info=True)
async def autostart_subscriptions() -> None:
@@ -39,11 +48,12 @@ async def autostart_subscriptions() -> None:
logger.info("[AUTOSTART] Initiating subscription auto-start process...")
try:
# Use the new SubscriptionManager auto-start method
# Use the SubscriptionManager auto-start method
await subscription_manager.auto_start_all_subscriptions()
logger.info("[AUTOSTART] Auto-start process completed successfully")
except Exception as e:
logger.error(f"[AUTOSTART] Failed during auto-start process: {e}", exc_info=True)
raise # Propagate so ensure_subscriptions_started doesn't mark as started
# Optional log file subscription
log_path = os.getenv("UNRAID_AUTOSTART_LOG_PATH")

View File

@@ -1,8 +1,79 @@
"""Shared utilities for the subscription system."""
import ssl as _ssl
from typing import Any
from ..config.settings import UNRAID_VERIFY_SSL
from ..config.settings import UNRAID_API_URL, UNRAID_VERIFY_SSL
def build_ws_url() -> str:
"""Build a WebSocket URL from the configured UNRAID_API_URL.
Converts http(s) scheme to ws(s) and ensures /graphql path suffix.
Returns:
The WebSocket URL string (e.g. "wss://10.1.0.2:31337/graphql").
Raises:
ValueError: If UNRAID_API_URL is not configured or has an unrecognised scheme.
"""
if not UNRAID_API_URL:
raise ValueError("UNRAID_API_URL is not configured")
if UNRAID_API_URL.startswith("https://"):
ws_url = "wss://" + UNRAID_API_URL[len("https://") :]
elif UNRAID_API_URL.startswith("http://"):
ws_url = "ws://" + UNRAID_API_URL[len("http://") :]
elif UNRAID_API_URL.startswith(("ws://", "wss://")):
ws_url = UNRAID_API_URL # Already a WebSocket URL
else:
raise ValueError(
f"UNRAID_API_URL must start with http://, https://, ws://, or wss://. "
f"Got: {UNRAID_API_URL[:20]}..."
)
if not ws_url.endswith("/graphql"):
ws_url = ws_url.rstrip("/") + "/graphql"
return ws_url
def _analyze_subscription_status(
status: dict[str, Any],
) -> tuple[int, list[dict[str, Any]]]:
"""Analyze subscription status dict, returning error count and connection issues.
Only reports connection_issues for subscriptions that are currently in a
failure state (not recovered ones that happen to have a stale last_error).
Args:
status: Dict of subscription name -> status info from get_subscription_status().
Returns:
Tuple of (error_count, connection_issues_list).
"""
_error_states = frozenset(
{"error", "auth_failed", "timeout", "max_retries_exceeded", "invalid_uri"}
)
error_count = 0
connection_issues: list[dict[str, Any]] = []
for sub_name, sub_status in status.items():
runtime = sub_status.get("runtime", {})
conn_state = runtime.get("connection_state", "unknown")
if conn_state in _error_states:
error_count += 1
# Gate on current failure state so recovered subscriptions are not reported
if runtime.get("last_error") and conn_state in _error_states:
connection_issues.append(
{
"subscription": sub_name,
"state": conn_state,
"error": runtime["last_error"],
}
)
return error_count, connection_issues
def build_ws_ssl_context(ws_url: str) -> _ssl.SSLContext | None: