mirror of
https://github.com/jmagar/unraid-mcp.git
synced 2026-03-23 04:29:17 -07:00
merge: subscription fixes from worktree (validation, byte cap, autostart, URL scheme)
This commit is contained in:
@@ -6,7 +6,6 @@ development and debugging purposes.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import re
|
||||
from datetime import UTC, datetime
|
||||
@@ -25,35 +24,38 @@ from .resources import ensure_subscriptions_started
|
||||
from .utils import _analyze_subscription_status, build_ws_ssl_context, build_ws_url
|
||||
|
||||
|
||||
_ALLOWED_SUBSCRIPTION_NAMES = frozenset(
|
||||
# Schema field names that appear inside the selection set of allowed subscriptions.
|
||||
# The regex _SUBSCRIPTION_NAME_PATTERN extracts the first identifier after the
|
||||
# opening "{", so we list the actual field names used in queries (e.g. "logFile"),
|
||||
# NOT the operation-level names (e.g. "logFileSubscription").
|
||||
_ALLOWED_SUBSCRIPTION_FIELDS = frozenset(
|
||||
{
|
||||
"logFileSubscription",
|
||||
"containerStatsSubscription",
|
||||
"cpuSubscription",
|
||||
"memorySubscription",
|
||||
"arraySubscription",
|
||||
"networkSubscription",
|
||||
"dockerSubscription",
|
||||
"vmSubscription",
|
||||
"logFile",
|
||||
"containerStats",
|
||||
"cpu",
|
||||
"memory",
|
||||
"array",
|
||||
"network",
|
||||
"docker",
|
||||
"vm",
|
||||
}
|
||||
)
|
||||
|
||||
# Pattern: must start with "subscription" and contain only a known subscription name.
|
||||
# _FORBIDDEN_KEYWORDS rejects any query that contains standalone "mutation" or "query"
|
||||
# as distinct words. Word boundaries (\b) ensure "mutationField"-style identifiers are
|
||||
# not rejected — only bare "mutation" or "query" operation keywords are blocked.
|
||||
# Pattern: must start with "subscription" keyword, then extract the first selected
|
||||
# field name (the word immediately after "{").
|
||||
_SUBSCRIPTION_NAME_PATTERN = re.compile(r"^\s*subscription\b[^{]*\{\s*(\w+)", re.IGNORECASE)
|
||||
# Reject any query that contains a bare "mutation" or "query" operation keyword.
|
||||
_FORBIDDEN_KEYWORDS = re.compile(r"\b(mutation|query)\b", re.IGNORECASE)
|
||||
|
||||
|
||||
def _validate_subscription_query(query: str) -> str:
|
||||
"""Validate that a subscription query is safe to execute.
|
||||
|
||||
Only allows subscription operations targeting whitelisted subscription names.
|
||||
Only allows subscription operations targeting whitelisted schema field names.
|
||||
Rejects any query containing mutation/query keywords.
|
||||
|
||||
Returns:
|
||||
The extracted subscription name.
|
||||
The extracted field name (e.g. "logFile").
|
||||
|
||||
Raises:
|
||||
ToolError: If the query fails validation.
|
||||
@@ -65,17 +67,17 @@ def _validate_subscription_query(query: str) -> str:
|
||||
if not match:
|
||||
raise ToolError(
|
||||
"Query rejected: must start with 'subscription' and contain a valid "
|
||||
"subscription operation. Example: subscription { logFileSubscription { ... } }"
|
||||
'subscription field. Example: subscription { logFile(path: "/var/log/syslog") { content } }'
|
||||
)
|
||||
|
||||
sub_name = match.group(1)
|
||||
if sub_name not in _ALLOWED_SUBSCRIPTION_NAMES:
|
||||
field_name = match.group(1)
|
||||
if field_name not in _ALLOWED_SUBSCRIPTION_FIELDS:
|
||||
raise ToolError(
|
||||
f"Subscription '{sub_name}' is not allowed. "
|
||||
f"Allowed subscriptions: {sorted(_ALLOWED_SUBSCRIPTION_NAMES)}"
|
||||
f"Subscription field '{field_name}' is not allowed. "
|
||||
f"Allowed fields: {sorted(_ALLOWED_SUBSCRIPTION_FIELDS)}"
|
||||
)
|
||||
|
||||
return sub_name
|
||||
return field_name
|
||||
|
||||
|
||||
def register_diagnostic_tools(mcp: FastMCP) -> None:
|
||||
@@ -90,10 +92,8 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
||||
"""Test a GraphQL subscription query directly to debug schema issues.
|
||||
|
||||
Use this to find working subscription field names and structure.
|
||||
Only whitelisted subscriptions are allowed (logFileSubscription,
|
||||
containerStatsSubscription, cpuSubscription, memorySubscription,
|
||||
arraySubscription, networkSubscription, dockerSubscription,
|
||||
vmSubscription).
|
||||
Only whitelisted schema fields are permitted (logFile, containerStats,
|
||||
cpu, memory, array, network, docker, vm).
|
||||
|
||||
Args:
|
||||
subscription_query: The GraphQL subscription query to test
|
||||
@@ -101,12 +101,13 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
||||
Returns:
|
||||
Dict containing test results and response data
|
||||
"""
|
||||
# Validate before any network I/O
|
||||
sub_name = _validate_subscription_query(subscription_query)
|
||||
# Validate before any network I/O (Bug 1 fix)
|
||||
field_name = _validate_subscription_query(subscription_query)
|
||||
|
||||
try:
|
||||
logger.info(f"[TEST_SUBSCRIPTION] Testing validated subscription '{sub_name}'")
|
||||
logger.info(f"[TEST_SUBSCRIPTION] Testing validated subscription field '{field_name}'")
|
||||
|
||||
# Build WebSocket URL — raises ValueError on invalid/missing scheme (Bug 4 fix)
|
||||
try:
|
||||
ws_url = build_ws_url()
|
||||
except ValueError as e:
|
||||
@@ -187,11 +188,18 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
||||
# Get comprehensive status
|
||||
status = await subscription_manager.get_subscription_status()
|
||||
|
||||
# Analyze connection issues and error counts via the shared helper.
|
||||
# This ensures "invalid_uri" and all other error states are counted
|
||||
# consistently with the health tool's _diagnose_subscriptions path.
|
||||
# Analyze connection issues and error counts via shared helper.
|
||||
# Gates connection_issues on current failure state (Bug 5 fix).
|
||||
error_count, connection_issues = _analyze_subscription_status(status)
|
||||
|
||||
# Calculate WebSocket URL
|
||||
ws_url_display: str | None = None
|
||||
if UNRAID_API_URL:
|
||||
try:
|
||||
ws_url_display = build_ws_url()
|
||||
except ValueError:
|
||||
ws_url_display = None
|
||||
|
||||
# Add environment info with explicit typing
|
||||
diagnostic_info: dict[str, Any] = {
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
@@ -200,7 +208,7 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
||||
"max_reconnect_attempts": subscription_manager.max_reconnect_attempts,
|
||||
"unraid_api_url": safe_display_url(UNRAID_API_URL),
|
||||
"api_key_configured": bool(UNRAID_API_KEY),
|
||||
"websocket_url": None,
|
||||
"websocket_url": ws_url_display,
|
||||
},
|
||||
"subscriptions": status,
|
||||
"summary": {
|
||||
@@ -217,10 +225,6 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
||||
},
|
||||
}
|
||||
|
||||
# Calculate WebSocket URL (stays None if UNRAID_API_URL not configured)
|
||||
with contextlib.suppress(ValueError):
|
||||
diagnostic_info["environment"]["websocket_url"] = build_ws_url()
|
||||
|
||||
# Add troubleshooting recommendations
|
||||
recommendations: list[str] = []
|
||||
|
||||
|
||||
@@ -74,6 +74,57 @@ def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]:
|
||||
return result
|
||||
|
||||
|
||||
# Resource data size limits to prevent unbounded memory growth
|
||||
_MAX_RESOURCE_DATA_BYTES = 1_048_576 # 1 MB
|
||||
_MAX_RESOURCE_DATA_LINES = 5_000
|
||||
|
||||
|
||||
def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Cap log content in subscription data to prevent unbounded memory growth.
|
||||
|
||||
Returns a new dict — does NOT mutate the input. If any nested 'content'
|
||||
field (from log subscriptions) exceeds the byte limit, truncate it to the
|
||||
most recent _MAX_RESOURCE_DATA_LINES lines.
|
||||
|
||||
The final content is guaranteed to be <= _MAX_RESOURCE_DATA_BYTES.
|
||||
"""
|
||||
result: dict[str, Any] = {}
|
||||
for key, value in data.items():
|
||||
if isinstance(value, dict):
|
||||
result[key] = _cap_log_content(value)
|
||||
elif (
|
||||
key == "content"
|
||||
and isinstance(value, str)
|
||||
# Pre-check uses byte count so multibyte UTF-8 chars cannot bypass the cap
|
||||
and len(value.encode("utf-8", errors="replace")) > _MAX_RESOURCE_DATA_BYTES
|
||||
):
|
||||
lines = value.splitlines()
|
||||
original_line_count = len(lines)
|
||||
|
||||
# Keep most recent lines first.
|
||||
if len(lines) > _MAX_RESOURCE_DATA_LINES:
|
||||
lines = lines[-_MAX_RESOURCE_DATA_LINES:]
|
||||
|
||||
truncated = "\n".join(lines)
|
||||
# Encode once and slice bytes instead of an O(n²) line-trim loop
|
||||
encoded = truncated.encode("utf-8", errors="replace")
|
||||
if len(encoded) > _MAX_RESOURCE_DATA_BYTES:
|
||||
truncated = encoded[-_MAX_RESOURCE_DATA_BYTES:].decode("utf-8", errors="ignore")
|
||||
# Strip partial first line that may have been cut mid-character
|
||||
nl_pos = truncated.find("\n")
|
||||
if nl_pos != -1:
|
||||
truncated = truncated[nl_pos + 1 :]
|
||||
|
||||
logger.warning(
|
||||
f"[RESOURCE] Capped log content from {original_line_count} to "
|
||||
f"{len(lines)} lines ({len(value)} -> {len(truncated)} chars)"
|
||||
)
|
||||
result[key] = truncated
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
|
||||
class SubscriptionManager:
|
||||
"""Manages GraphQL subscriptions and converts them to MCP resources."""
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ async def autostart_subscriptions() -> None:
|
||||
logger.info("[AUTOSTART] Initiating subscription auto-start process...")
|
||||
|
||||
try:
|
||||
# Use the new SubscriptionManager auto-start method
|
||||
# Use the SubscriptionManager auto-start method
|
||||
await subscription_manager.auto_start_all_subscriptions()
|
||||
logger.info("[AUTOSTART] Auto-start process completed successfully")
|
||||
except Exception as e:
|
||||
|
||||
@@ -15,7 +15,7 @@ def build_ws_url() -> str:
|
||||
The WebSocket URL string (e.g. "wss://10.1.0.2:31337/graphql").
|
||||
|
||||
Raises:
|
||||
ValueError: If UNRAID_API_URL is not configured.
|
||||
ValueError: If UNRAID_API_URL is not configured or has an unrecognised scheme.
|
||||
"""
|
||||
if not UNRAID_API_URL:
|
||||
raise ValueError("UNRAID_API_URL is not configured")
|
||||
@@ -24,8 +24,13 @@ def build_ws_url() -> str:
|
||||
ws_url = "wss://" + UNRAID_API_URL[len("https://") :]
|
||||
elif UNRAID_API_URL.startswith("http://"):
|
||||
ws_url = "ws://" + UNRAID_API_URL[len("http://") :]
|
||||
elif UNRAID_API_URL.startswith(("ws://", "wss://")):
|
||||
ws_url = UNRAID_API_URL # Already a WebSocket URL
|
||||
else:
|
||||
ws_url = UNRAID_API_URL
|
||||
raise ValueError(
|
||||
f"UNRAID_API_URL must start with http://, https://, ws://, or wss://. "
|
||||
f"Got: {UNRAID_API_URL[:20]}..."
|
||||
)
|
||||
|
||||
if not ws_url.endswith("/graphql"):
|
||||
ws_url = ws_url.rstrip("/") + "/graphql"
|
||||
@@ -60,8 +65,8 @@ def _analyze_subscription_status(
|
||||
) -> tuple[int, list[dict[str, Any]]]:
|
||||
"""Analyze subscription status dict, returning error count and connection issues.
|
||||
|
||||
This is the canonical, shared implementation used by both the health tool
|
||||
and the subscription diagnostics tool.
|
||||
Only reports connection_issues for subscriptions that are currently in a
|
||||
failure state (not recovered ones that happen to have a stale last_error).
|
||||
|
||||
Args:
|
||||
status: Dict of subscription name -> status info from get_subscription_status().
|
||||
@@ -69,15 +74,19 @@ def _analyze_subscription_status(
|
||||
Returns:
|
||||
Tuple of (error_count, connection_issues_list).
|
||||
"""
|
||||
_error_states = frozenset(
|
||||
{"error", "auth_failed", "timeout", "max_retries_exceeded", "invalid_uri"}
|
||||
)
|
||||
error_count = 0
|
||||
connection_issues: list[dict[str, Any]] = []
|
||||
|
||||
for sub_name, sub_status in status.items():
|
||||
runtime = sub_status.get("runtime", {})
|
||||
conn_state = runtime.get("connection_state", "unknown")
|
||||
if conn_state in ("error", "auth_failed", "timeout", "max_retries_exceeded", "invalid_uri"):
|
||||
if conn_state in _error_states:
|
||||
error_count += 1
|
||||
if runtime.get("last_error"):
|
||||
# Gate on current failure state so recovered subscriptions are not reported
|
||||
if runtime.get("last_error") and conn_state in _error_states:
|
||||
connection_issues.append(
|
||||
{
|
||||
"subscription": sub_name,
|
||||
|
||||
Reference in New Issue
Block a user