merge: subscription fixes from worktree (validation, byte cap, autostart, URL scheme)

This commit is contained in:
Jacob Magar
2026-03-13 10:48:05 -04:00
4 changed files with 108 additions and 44 deletions

View File

@@ -6,7 +6,6 @@ development and debugging purposes.
"""
import asyncio
import contextlib
import json
import re
from datetime import UTC, datetime
@@ -25,35 +24,38 @@ from .resources import ensure_subscriptions_started
from .utils import _analyze_subscription_status, build_ws_ssl_context, build_ws_url
_ALLOWED_SUBSCRIPTION_NAMES = frozenset(
# Schema field names that appear inside the selection set of allowed subscriptions.
# The regex _SUBSCRIPTION_NAME_PATTERN extracts the first identifier after the
# opening "{", so we list the actual field names used in queries (e.g. "logFile"),
# NOT the operation-level names (e.g. "logFileSubscription").
_ALLOWED_SUBSCRIPTION_FIELDS = frozenset(
{
"logFileSubscription",
"containerStatsSubscription",
"cpuSubscription",
"memorySubscription",
"arraySubscription",
"networkSubscription",
"dockerSubscription",
"vmSubscription",
"logFile",
"containerStats",
"cpu",
"memory",
"array",
"network",
"docker",
"vm",
}
)
# Pattern: must start with "subscription" and contain only a known subscription name.
# _FORBIDDEN_KEYWORDS rejects any query that contains standalone "mutation" or "query"
# as distinct words. Word boundaries (\b) ensure "mutationField"-style identifiers are
# not rejected — only bare "mutation" or "query" operation keywords are blocked.
# Pattern: must start with "subscription" keyword, then extract the first selected
# field name (the word immediately after "{").
_SUBSCRIPTION_NAME_PATTERN = re.compile(r"^\s*subscription\b[^{]*\{\s*(\w+)", re.IGNORECASE)
# Reject any query that contains a bare "mutation" or "query" operation keyword.
_FORBIDDEN_KEYWORDS = re.compile(r"\b(mutation|query)\b", re.IGNORECASE)
def _validate_subscription_query(query: str) -> str:
"""Validate that a subscription query is safe to execute.
Only allows subscription operations targeting whitelisted subscription names.
Only allows subscription operations targeting whitelisted schema field names.
Rejects any query containing mutation/query keywords.
Returns:
The extracted subscription name.
The extracted field name (e.g. "logFile").
Raises:
ToolError: If the query fails validation.
@@ -65,17 +67,17 @@ def _validate_subscription_query(query: str) -> str:
if not match:
raise ToolError(
"Query rejected: must start with 'subscription' and contain a valid "
"subscription operation. Example: subscription { logFileSubscription { ... } }"
'subscription field. Example: subscription { logFile(path: "/var/log/syslog") { content } }'
)
sub_name = match.group(1)
if sub_name not in _ALLOWED_SUBSCRIPTION_NAMES:
field_name = match.group(1)
if field_name not in _ALLOWED_SUBSCRIPTION_FIELDS:
raise ToolError(
f"Subscription '{sub_name}' is not allowed. "
f"Allowed subscriptions: {sorted(_ALLOWED_SUBSCRIPTION_NAMES)}"
f"Subscription field '{field_name}' is not allowed. "
f"Allowed fields: {sorted(_ALLOWED_SUBSCRIPTION_FIELDS)}"
)
return sub_name
return field_name
def register_diagnostic_tools(mcp: FastMCP) -> None:
@@ -90,10 +92,8 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
"""Test a GraphQL subscription query directly to debug schema issues.
Use this to find working subscription field names and structure.
Only whitelisted subscriptions are allowed (logFileSubscription,
containerStatsSubscription, cpuSubscription, memorySubscription,
arraySubscription, networkSubscription, dockerSubscription,
vmSubscription).
Only whitelisted schema fields are permitted (logFile, containerStats,
cpu, memory, array, network, docker, vm).
Args:
subscription_query: The GraphQL subscription query to test
@@ -101,12 +101,13 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
Returns:
Dict containing test results and response data
"""
# Validate before any network I/O
sub_name = _validate_subscription_query(subscription_query)
# Validate before any network I/O (Bug 1 fix)
field_name = _validate_subscription_query(subscription_query)
try:
logger.info(f"[TEST_SUBSCRIPTION] Testing validated subscription '{sub_name}'")
logger.info(f"[TEST_SUBSCRIPTION] Testing validated subscription field '{field_name}'")
# Build WebSocket URL — raises ValueError on invalid/missing scheme (Bug 4 fix)
try:
ws_url = build_ws_url()
except ValueError as e:
@@ -187,11 +188,18 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
# Get comprehensive status
status = await subscription_manager.get_subscription_status()
# Analyze connection issues and error counts via the shared helper.
# This ensures "invalid_uri" and all other error states are counted
# consistently with the health tool's _diagnose_subscriptions path.
# Analyze connection issues and error counts via shared helper.
# Gates connection_issues on current failure state (Bug 5 fix).
error_count, connection_issues = _analyze_subscription_status(status)
# Calculate WebSocket URL
ws_url_display: str | None = None
if UNRAID_API_URL:
try:
ws_url_display = build_ws_url()
except ValueError:
ws_url_display = None
# Add environment info with explicit typing
diagnostic_info: dict[str, Any] = {
"timestamp": datetime.now(UTC).isoformat(),
@@ -200,7 +208,7 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
"max_reconnect_attempts": subscription_manager.max_reconnect_attempts,
"unraid_api_url": safe_display_url(UNRAID_API_URL),
"api_key_configured": bool(UNRAID_API_KEY),
"websocket_url": None,
"websocket_url": ws_url_display,
},
"subscriptions": status,
"summary": {
@@ -217,10 +225,6 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
},
}
# Calculate WebSocket URL (stays None if UNRAID_API_URL not configured)
with contextlib.suppress(ValueError):
diagnostic_info["environment"]["websocket_url"] = build_ws_url()
# Add troubleshooting recommendations
recommendations: list[str] = []