merge: subscription fixes from worktree (validation, byte cap, autostart, URL scheme)

This commit is contained in:
Jacob Magar
2026-03-13 10:48:05 -04:00
4 changed files with 108 additions and 44 deletions

View File

@@ -6,7 +6,6 @@ development and debugging purposes.
"""
import asyncio
import contextlib
import json
import re
from datetime import UTC, datetime
@@ -25,35 +24,38 @@ from .resources import ensure_subscriptions_started
from .utils import _analyze_subscription_status, build_ws_ssl_context, build_ws_url
_ALLOWED_SUBSCRIPTION_NAMES = frozenset(
# Schema field names that appear inside the selection set of allowed subscriptions.
# The regex _SUBSCRIPTION_NAME_PATTERN extracts the first identifier after the
# opening "{", so we list the actual field names used in queries (e.g. "logFile"),
# NOT the operation-level names (e.g. "logFileSubscription").
_ALLOWED_SUBSCRIPTION_FIELDS = frozenset(
{
"logFileSubscription",
"containerStatsSubscription",
"cpuSubscription",
"memorySubscription",
"arraySubscription",
"networkSubscription",
"dockerSubscription",
"vmSubscription",
"logFile",
"containerStats",
"cpu",
"memory",
"array",
"network",
"docker",
"vm",
}
)
# Pattern: must start with "subscription" and contain only a known subscription name.
# _FORBIDDEN_KEYWORDS rejects any query that contains standalone "mutation" or "query"
# as distinct words. Word boundaries (\b) ensure "mutationField"-style identifiers are
# not rejected — only bare "mutation" or "query" operation keywords are blocked.
# Pattern: must start with "subscription" keyword, then extract the first selected
# field name (the word immediately after "{").
_SUBSCRIPTION_NAME_PATTERN = re.compile(r"^\s*subscription\b[^{]*\{\s*(\w+)", re.IGNORECASE)
# Reject any query that contains a bare "mutation" or "query" operation keyword.
_FORBIDDEN_KEYWORDS = re.compile(r"\b(mutation|query)\b", re.IGNORECASE)
def _validate_subscription_query(query: str) -> str:
"""Validate that a subscription query is safe to execute.
Only allows subscription operations targeting whitelisted subscription names.
Only allows subscription operations targeting whitelisted schema field names.
Rejects any query containing mutation/query keywords.
Returns:
The extracted subscription name.
The extracted field name (e.g. "logFile").
Raises:
ToolError: If the query fails validation.
@@ -65,17 +67,17 @@ def _validate_subscription_query(query: str) -> str:
if not match:
raise ToolError(
"Query rejected: must start with 'subscription' and contain a valid "
"subscription operation. Example: subscription { logFileSubscription { ... } }"
'subscription field. Example: subscription { logFile(path: "/var/log/syslog") { content } }'
)
sub_name = match.group(1)
if sub_name not in _ALLOWED_SUBSCRIPTION_NAMES:
field_name = match.group(1)
if field_name not in _ALLOWED_SUBSCRIPTION_FIELDS:
raise ToolError(
f"Subscription '{sub_name}' is not allowed. "
f"Allowed subscriptions: {sorted(_ALLOWED_SUBSCRIPTION_NAMES)}"
f"Subscription field '{field_name}' is not allowed. "
f"Allowed fields: {sorted(_ALLOWED_SUBSCRIPTION_FIELDS)}"
)
return sub_name
return field_name
def register_diagnostic_tools(mcp: FastMCP) -> None:
@@ -90,10 +92,8 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
"""Test a GraphQL subscription query directly to debug schema issues.
Use this to find working subscription field names and structure.
Only whitelisted subscriptions are allowed (logFileSubscription,
containerStatsSubscription, cpuSubscription, memorySubscription,
arraySubscription, networkSubscription, dockerSubscription,
vmSubscription).
Only whitelisted schema fields are permitted (logFile, containerStats,
cpu, memory, array, network, docker, vm).
Args:
subscription_query: The GraphQL subscription query to test
@@ -101,12 +101,13 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
Returns:
Dict containing test results and response data
"""
# Validate before any network I/O
sub_name = _validate_subscription_query(subscription_query)
# Validate before any network I/O (Bug 1 fix)
field_name = _validate_subscription_query(subscription_query)
try:
logger.info(f"[TEST_SUBSCRIPTION] Testing validated subscription '{sub_name}'")
logger.info(f"[TEST_SUBSCRIPTION] Testing validated subscription field '{field_name}'")
# Build WebSocket URL — raises ValueError on invalid/missing scheme (Bug 4 fix)
try:
ws_url = build_ws_url()
except ValueError as e:
@@ -187,11 +188,18 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
# Get comprehensive status
status = await subscription_manager.get_subscription_status()
# Analyze connection issues and error counts via the shared helper.
# This ensures "invalid_uri" and all other error states are counted
# consistently with the health tool's _diagnose_subscriptions path.
# Analyze connection issues and error counts via shared helper.
# Gates connection_issues on current failure state (Bug 5 fix).
error_count, connection_issues = _analyze_subscription_status(status)
# Calculate WebSocket URL
ws_url_display: str | None = None
if UNRAID_API_URL:
try:
ws_url_display = build_ws_url()
except ValueError:
ws_url_display = None
# Add environment info with explicit typing
diagnostic_info: dict[str, Any] = {
"timestamp": datetime.now(UTC).isoformat(),
@@ -200,7 +208,7 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
"max_reconnect_attempts": subscription_manager.max_reconnect_attempts,
"unraid_api_url": safe_display_url(UNRAID_API_URL),
"api_key_configured": bool(UNRAID_API_KEY),
"websocket_url": None,
"websocket_url": ws_url_display,
},
"subscriptions": status,
"summary": {
@@ -217,10 +225,6 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
},
}
# Calculate WebSocket URL (stays None if UNRAID_API_URL not configured)
with contextlib.suppress(ValueError):
diagnostic_info["environment"]["websocket_url"] = build_ws_url()
# Add troubleshooting recommendations
recommendations: list[str] = []

View File

@@ -74,6 +74,57 @@ def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]:
return result
# Resource data size limits to prevent unbounded memory growth
_MAX_RESOURCE_DATA_BYTES = 1_048_576 # 1 MB
_MAX_RESOURCE_DATA_LINES = 5_000
def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]:
"""Cap log content in subscription data to prevent unbounded memory growth.
Returns a new dict — does NOT mutate the input. If any nested 'content'
field (from log subscriptions) exceeds the byte limit, truncate it to the
most recent _MAX_RESOURCE_DATA_LINES lines.
The final content is guaranteed to be <= _MAX_RESOURCE_DATA_BYTES.
"""
result: dict[str, Any] = {}
for key, value in data.items():
if isinstance(value, dict):
result[key] = _cap_log_content(value)
elif (
key == "content"
and isinstance(value, str)
# Pre-check uses byte count so multibyte UTF-8 chars cannot bypass the cap
and len(value.encode("utf-8", errors="replace")) > _MAX_RESOURCE_DATA_BYTES
):
lines = value.splitlines()
original_line_count = len(lines)
# Keep most recent lines first.
if len(lines) > _MAX_RESOURCE_DATA_LINES:
lines = lines[-_MAX_RESOURCE_DATA_LINES:]
truncated = "\n".join(lines)
# Encode once and slice bytes instead of an O(n²) line-trim loop
encoded = truncated.encode("utf-8", errors="replace")
if len(encoded) > _MAX_RESOURCE_DATA_BYTES:
truncated = encoded[-_MAX_RESOURCE_DATA_BYTES:].decode("utf-8", errors="ignore")
# Strip partial first line that may have been cut mid-character
nl_pos = truncated.find("\n")
if nl_pos != -1:
truncated = truncated[nl_pos + 1 :]
logger.warning(
f"[RESOURCE] Capped log content from {original_line_count} to "
f"{len(lines)} lines ({len(value)} -> {len(truncated)} chars)"
)
result[key] = truncated
else:
result[key] = value
return result
class SubscriptionManager:
"""Manages GraphQL subscriptions and converts them to MCP resources."""

View File

@@ -48,7 +48,7 @@ async def autostart_subscriptions() -> None:
logger.info("[AUTOSTART] Initiating subscription auto-start process...")
try:
# Use the new SubscriptionManager auto-start method
# Use the SubscriptionManager auto-start method
await subscription_manager.auto_start_all_subscriptions()
logger.info("[AUTOSTART] Auto-start process completed successfully")
except Exception as e:

View File

@@ -15,7 +15,7 @@ def build_ws_url() -> str:
The WebSocket URL string (e.g. "wss://10.1.0.2:31337/graphql").
Raises:
ValueError: If UNRAID_API_URL is not configured.
ValueError: If UNRAID_API_URL is not configured or has an unrecognised scheme.
"""
if not UNRAID_API_URL:
raise ValueError("UNRAID_API_URL is not configured")
@@ -24,8 +24,13 @@ def build_ws_url() -> str:
ws_url = "wss://" + UNRAID_API_URL[len("https://") :]
elif UNRAID_API_URL.startswith("http://"):
ws_url = "ws://" + UNRAID_API_URL[len("http://") :]
elif UNRAID_API_URL.startswith(("ws://", "wss://")):
ws_url = UNRAID_API_URL # Already a WebSocket URL
else:
ws_url = UNRAID_API_URL
raise ValueError(
f"UNRAID_API_URL must start with http://, https://, ws://, or wss://. "
f"Got: {UNRAID_API_URL[:20]}..."
)
if not ws_url.endswith("/graphql"):
ws_url = ws_url.rstrip("/") + "/graphql"
@@ -60,8 +65,8 @@ def _analyze_subscription_status(
) -> tuple[int, list[dict[str, Any]]]:
"""Analyze subscription status dict, returning error count and connection issues.
This is the canonical, shared implementation used by both the health tool
and the subscription diagnostics tool.
Only reports connection_issues for subscriptions that are currently in a
failure state (not recovered ones that happen to have a stale last_error).
Args:
status: Dict of subscription name -> status info from get_subscription_status().
@@ -69,15 +74,19 @@ def _analyze_subscription_status(
Returns:
Tuple of (error_count, connection_issues_list).
"""
_error_states = frozenset(
{"error", "auth_failed", "timeout", "max_retries_exceeded", "invalid_uri"}
)
error_count = 0
connection_issues: list[dict[str, Any]] = []
for sub_name, sub_status in status.items():
runtime = sub_status.get("runtime", {})
conn_state = runtime.get("connection_state", "unknown")
if conn_state in ("error", "auth_failed", "timeout", "max_retries_exceeded", "invalid_uri"):
if conn_state in _error_states:
error_count += 1
if runtime.get("last_error"):
# Gate on current failure state so recovered subscriptions are not reported
if runtime.get("last_error") and conn_state in _error_states:
connection_issues.append(
{
"subscription": sub_name,