mirror of
https://github.com/jmagar/unraid-mcp.git
synced 2026-03-23 04:29:17 -07:00
merge: subscription fixes from worktree (validation, byte cap, autostart, URL scheme)
This commit is contained in:
@@ -6,7 +6,6 @@ development and debugging purposes.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
@@ -25,35 +24,38 @@ from .resources import ensure_subscriptions_started
|
|||||||
from .utils import _analyze_subscription_status, build_ws_ssl_context, build_ws_url
|
from .utils import _analyze_subscription_status, build_ws_ssl_context, build_ws_url
|
||||||
|
|
||||||
|
|
||||||
_ALLOWED_SUBSCRIPTION_NAMES = frozenset(
|
# Schema field names that appear inside the selection set of allowed subscriptions.
|
||||||
|
# The regex _SUBSCRIPTION_NAME_PATTERN extracts the first identifier after the
|
||||||
|
# opening "{", so we list the actual field names used in queries (e.g. "logFile"),
|
||||||
|
# NOT the operation-level names (e.g. "logFileSubscription").
|
||||||
|
_ALLOWED_SUBSCRIPTION_FIELDS = frozenset(
|
||||||
{
|
{
|
||||||
"logFileSubscription",
|
"logFile",
|
||||||
"containerStatsSubscription",
|
"containerStats",
|
||||||
"cpuSubscription",
|
"cpu",
|
||||||
"memorySubscription",
|
"memory",
|
||||||
"arraySubscription",
|
"array",
|
||||||
"networkSubscription",
|
"network",
|
||||||
"dockerSubscription",
|
"docker",
|
||||||
"vmSubscription",
|
"vm",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Pattern: must start with "subscription" and contain only a known subscription name.
|
# Pattern: must start with "subscription" keyword, then extract the first selected
|
||||||
# _FORBIDDEN_KEYWORDS rejects any query that contains standalone "mutation" or "query"
|
# field name (the word immediately after "{").
|
||||||
# as distinct words. Word boundaries (\b) ensure "mutationField"-style identifiers are
|
|
||||||
# not rejected — only bare "mutation" or "query" operation keywords are blocked.
|
|
||||||
_SUBSCRIPTION_NAME_PATTERN = re.compile(r"^\s*subscription\b[^{]*\{\s*(\w+)", re.IGNORECASE)
|
_SUBSCRIPTION_NAME_PATTERN = re.compile(r"^\s*subscription\b[^{]*\{\s*(\w+)", re.IGNORECASE)
|
||||||
|
# Reject any query that contains a bare "mutation" or "query" operation keyword.
|
||||||
_FORBIDDEN_KEYWORDS = re.compile(r"\b(mutation|query)\b", re.IGNORECASE)
|
_FORBIDDEN_KEYWORDS = re.compile(r"\b(mutation|query)\b", re.IGNORECASE)
|
||||||
|
|
||||||
|
|
||||||
def _validate_subscription_query(query: str) -> str:
|
def _validate_subscription_query(query: str) -> str:
|
||||||
"""Validate that a subscription query is safe to execute.
|
"""Validate that a subscription query is safe to execute.
|
||||||
|
|
||||||
Only allows subscription operations targeting whitelisted subscription names.
|
Only allows subscription operations targeting whitelisted schema field names.
|
||||||
Rejects any query containing mutation/query keywords.
|
Rejects any query containing mutation/query keywords.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The extracted subscription name.
|
The extracted field name (e.g. "logFile").
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ToolError: If the query fails validation.
|
ToolError: If the query fails validation.
|
||||||
@@ -65,17 +67,17 @@ def _validate_subscription_query(query: str) -> str:
|
|||||||
if not match:
|
if not match:
|
||||||
raise ToolError(
|
raise ToolError(
|
||||||
"Query rejected: must start with 'subscription' and contain a valid "
|
"Query rejected: must start with 'subscription' and contain a valid "
|
||||||
"subscription operation. Example: subscription { logFileSubscription { ... } }"
|
'subscription field. Example: subscription { logFile(path: "/var/log/syslog") { content } }'
|
||||||
)
|
)
|
||||||
|
|
||||||
sub_name = match.group(1)
|
field_name = match.group(1)
|
||||||
if sub_name not in _ALLOWED_SUBSCRIPTION_NAMES:
|
if field_name not in _ALLOWED_SUBSCRIPTION_FIELDS:
|
||||||
raise ToolError(
|
raise ToolError(
|
||||||
f"Subscription '{sub_name}' is not allowed. "
|
f"Subscription field '{field_name}' is not allowed. "
|
||||||
f"Allowed subscriptions: {sorted(_ALLOWED_SUBSCRIPTION_NAMES)}"
|
f"Allowed fields: {sorted(_ALLOWED_SUBSCRIPTION_FIELDS)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return sub_name
|
return field_name
|
||||||
|
|
||||||
|
|
||||||
def register_diagnostic_tools(mcp: FastMCP) -> None:
|
def register_diagnostic_tools(mcp: FastMCP) -> None:
|
||||||
@@ -90,10 +92,8 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
|||||||
"""Test a GraphQL subscription query directly to debug schema issues.
|
"""Test a GraphQL subscription query directly to debug schema issues.
|
||||||
|
|
||||||
Use this to find working subscription field names and structure.
|
Use this to find working subscription field names and structure.
|
||||||
Only whitelisted subscriptions are allowed (logFileSubscription,
|
Only whitelisted schema fields are permitted (logFile, containerStats,
|
||||||
containerStatsSubscription, cpuSubscription, memorySubscription,
|
cpu, memory, array, network, docker, vm).
|
||||||
arraySubscription, networkSubscription, dockerSubscription,
|
|
||||||
vmSubscription).
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
subscription_query: The GraphQL subscription query to test
|
subscription_query: The GraphQL subscription query to test
|
||||||
@@ -101,12 +101,13 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
|||||||
Returns:
|
Returns:
|
||||||
Dict containing test results and response data
|
Dict containing test results and response data
|
||||||
"""
|
"""
|
||||||
# Validate before any network I/O
|
# Validate before any network I/O (Bug 1 fix)
|
||||||
sub_name = _validate_subscription_query(subscription_query)
|
field_name = _validate_subscription_query(subscription_query)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"[TEST_SUBSCRIPTION] Testing validated subscription '{sub_name}'")
|
logger.info(f"[TEST_SUBSCRIPTION] Testing validated subscription field '{field_name}'")
|
||||||
|
|
||||||
|
# Build WebSocket URL — raises ValueError on invalid/missing scheme (Bug 4 fix)
|
||||||
try:
|
try:
|
||||||
ws_url = build_ws_url()
|
ws_url = build_ws_url()
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
@@ -187,11 +188,18 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
|||||||
# Get comprehensive status
|
# Get comprehensive status
|
||||||
status = await subscription_manager.get_subscription_status()
|
status = await subscription_manager.get_subscription_status()
|
||||||
|
|
||||||
# Analyze connection issues and error counts via the shared helper.
|
# Analyze connection issues and error counts via shared helper.
|
||||||
# This ensures "invalid_uri" and all other error states are counted
|
# Gates connection_issues on current failure state (Bug 5 fix).
|
||||||
# consistently with the health tool's _diagnose_subscriptions path.
|
|
||||||
error_count, connection_issues = _analyze_subscription_status(status)
|
error_count, connection_issues = _analyze_subscription_status(status)
|
||||||
|
|
||||||
|
# Calculate WebSocket URL
|
||||||
|
ws_url_display: str | None = None
|
||||||
|
if UNRAID_API_URL:
|
||||||
|
try:
|
||||||
|
ws_url_display = build_ws_url()
|
||||||
|
except ValueError:
|
||||||
|
ws_url_display = None
|
||||||
|
|
||||||
# Add environment info with explicit typing
|
# Add environment info with explicit typing
|
||||||
diagnostic_info: dict[str, Any] = {
|
diagnostic_info: dict[str, Any] = {
|
||||||
"timestamp": datetime.now(UTC).isoformat(),
|
"timestamp": datetime.now(UTC).isoformat(),
|
||||||
@@ -200,7 +208,7 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
|||||||
"max_reconnect_attempts": subscription_manager.max_reconnect_attempts,
|
"max_reconnect_attempts": subscription_manager.max_reconnect_attempts,
|
||||||
"unraid_api_url": safe_display_url(UNRAID_API_URL),
|
"unraid_api_url": safe_display_url(UNRAID_API_URL),
|
||||||
"api_key_configured": bool(UNRAID_API_KEY),
|
"api_key_configured": bool(UNRAID_API_KEY),
|
||||||
"websocket_url": None,
|
"websocket_url": ws_url_display,
|
||||||
},
|
},
|
||||||
"subscriptions": status,
|
"subscriptions": status,
|
||||||
"summary": {
|
"summary": {
|
||||||
@@ -217,10 +225,6 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Calculate WebSocket URL (stays None if UNRAID_API_URL not configured)
|
|
||||||
with contextlib.suppress(ValueError):
|
|
||||||
diagnostic_info["environment"]["websocket_url"] = build_ws_url()
|
|
||||||
|
|
||||||
# Add troubleshooting recommendations
|
# Add troubleshooting recommendations
|
||||||
recommendations: list[str] = []
|
recommendations: list[str] = []
|
||||||
|
|
||||||
|
|||||||
@@ -74,6 +74,57 @@ def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# Resource data size limits to prevent unbounded memory growth
_MAX_RESOURCE_DATA_BYTES = 1_048_576  # 1 MB
_MAX_RESOURCE_DATA_LINES = 5_000


def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]:
    """Cap oversized log content in subscription data.

    Returns a new dict; the input is never mutated. Nested dicts are processed
    recursively and every other value is passed through unchanged.

    A string stored under the key ``"content"`` is capped only when its UTF-8
    encoding exceeds ``_MAX_RESOURCE_DATA_BYTES``. Capping happens in two steps:

    1. Keep at most the last ``_MAX_RESOURCE_DATA_LINES`` lines. Lines are split
       with ``str.splitlines`` and re-joined with a single newline, so line
       endings are normalised and a trailing newline is dropped.
    2. If the result is still too large, keep only the last
       ``_MAX_RESOURCE_DATA_BYTES`` bytes. When that cut lands in the middle of
       a line, the leading partial line is discarded; when it lands exactly on
       a line boundary, the complete first line is kept.

    The final content is guaranteed to be <= ``_MAX_RESOURCE_DATA_BYTES`` bytes
    when encoded as UTF-8.

    Args:
        data: Subscription payload, possibly containing nested dicts.

    Returns:
        A new dict with the same keys, where oversized ``"content"`` strings
        have been truncated to their most recent portion.
    """
    result: dict[str, Any] = {}
    for key, value in data.items():
        if isinstance(value, dict):
            result[key] = _cap_log_content(value)
            continue

        # The size check uses the encoded byte count (not len()) so multibyte
        # UTF-8 characters cannot bypass the cap.
        if (
            key != "content"
            or not isinstance(value, str)
            or len(value.encode("utf-8", errors="replace")) <= _MAX_RESOURCE_DATA_BYTES
        ):
            result[key] = value
            continue

        lines = value.splitlines()
        original_line_count = len(lines)

        # Keep only the most recent (last) lines.
        truncated = "\n".join(lines[-_MAX_RESOURCE_DATA_LINES:])

        # Encode once and slice bytes instead of trimming line by line.
        encoded = truncated.encode("utf-8", errors="replace")
        if len(encoded) > _MAX_RESOURCE_DATA_BYTES:
            cut = len(encoded) - _MAX_RESOURCE_DATA_BYTES  # always >= 1 here
            # errors="ignore" drops a multibyte character split by the cut.
            truncated = encoded[cut:].decode("utf-8", errors="ignore")
            # Only discard the first line when the cut really landed mid-line.
            # If the byte just before the window is a newline, the window
            # already starts on a line boundary and nothing needs stripping.
            if encoded[cut - 1 : cut] != b"\n":
                nl_pos = truncated.find("\n")
                if nl_pos != -1:
                    truncated = truncated[nl_pos + 1 :]

        # Report the number of lines actually kept, which can be lower than
        # the line cap when the byte cap also applied.
        kept_line_count = truncated.count("\n") + 1 if truncated else 0
        logger.warning(
            f"[RESOURCE] Capped log content from {original_line_count} to "
            f"{kept_line_count} lines ({len(value)} -> {len(truncated)} chars)"
        )
        result[key] = truncated
    return result
|
||||||
|
|
||||||
|
|
||||||
class SubscriptionManager:
|
class SubscriptionManager:
|
||||||
"""Manages GraphQL subscriptions and converts them to MCP resources."""
|
"""Manages GraphQL subscriptions and converts them to MCP resources."""
|
||||||
|
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ async def autostart_subscriptions() -> None:
|
|||||||
logger.info("[AUTOSTART] Initiating subscription auto-start process...")
|
logger.info("[AUTOSTART] Initiating subscription auto-start process...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Use the new SubscriptionManager auto-start method
|
# Use the SubscriptionManager auto-start method
|
||||||
await subscription_manager.auto_start_all_subscriptions()
|
await subscription_manager.auto_start_all_subscriptions()
|
||||||
logger.info("[AUTOSTART] Auto-start process completed successfully")
|
logger.info("[AUTOSTART] Auto-start process completed successfully")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ def build_ws_url() -> str:
|
|||||||
The WebSocket URL string (e.g. "wss://10.1.0.2:31337/graphql").
|
The WebSocket URL string (e.g. "wss://10.1.0.2:31337/graphql").
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If UNRAID_API_URL is not configured.
|
ValueError: If UNRAID_API_URL is not configured or has an unrecognised scheme.
|
||||||
"""
|
"""
|
||||||
if not UNRAID_API_URL:
|
if not UNRAID_API_URL:
|
||||||
raise ValueError("UNRAID_API_URL is not configured")
|
raise ValueError("UNRAID_API_URL is not configured")
|
||||||
@@ -24,8 +24,13 @@ def build_ws_url() -> str:
|
|||||||
ws_url = "wss://" + UNRAID_API_URL[len("https://") :]
|
ws_url = "wss://" + UNRAID_API_URL[len("https://") :]
|
||||||
elif UNRAID_API_URL.startswith("http://"):
|
elif UNRAID_API_URL.startswith("http://"):
|
||||||
ws_url = "ws://" + UNRAID_API_URL[len("http://") :]
|
ws_url = "ws://" + UNRAID_API_URL[len("http://") :]
|
||||||
|
elif UNRAID_API_URL.startswith(("ws://", "wss://")):
|
||||||
|
ws_url = UNRAID_API_URL # Already a WebSocket URL
|
||||||
else:
|
else:
|
||||||
ws_url = UNRAID_API_URL
|
raise ValueError(
|
||||||
|
f"UNRAID_API_URL must start with http://, https://, ws://, or wss://. "
|
||||||
|
f"Got: {UNRAID_API_URL[:20]}..."
|
||||||
|
)
|
||||||
|
|
||||||
if not ws_url.endswith("/graphql"):
|
if not ws_url.endswith("/graphql"):
|
||||||
ws_url = ws_url.rstrip("/") + "/graphql"
|
ws_url = ws_url.rstrip("/") + "/graphql"
|
||||||
@@ -60,8 +65,8 @@ def _analyze_subscription_status(
|
|||||||
) -> tuple[int, list[dict[str, Any]]]:
|
) -> tuple[int, list[dict[str, Any]]]:
|
||||||
"""Analyze subscription status dict, returning error count and connection issues.
|
"""Analyze subscription status dict, returning error count and connection issues.
|
||||||
|
|
||||||
This is the canonical, shared implementation used by both the health tool
|
Only reports connection_issues for subscriptions that are currently in a
|
||||||
and the subscription diagnostics tool.
|
failure state (not recovered ones that happen to have a stale last_error).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
status: Dict of subscription name -> status info from get_subscription_status().
|
status: Dict of subscription name -> status info from get_subscription_status().
|
||||||
@@ -69,15 +74,19 @@ def _analyze_subscription_status(
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple of (error_count, connection_issues_list).
|
Tuple of (error_count, connection_issues_list).
|
||||||
"""
|
"""
|
||||||
|
_error_states = frozenset(
|
||||||
|
{"error", "auth_failed", "timeout", "max_retries_exceeded", "invalid_uri"}
|
||||||
|
)
|
||||||
error_count = 0
|
error_count = 0
|
||||||
connection_issues: list[dict[str, Any]] = []
|
connection_issues: list[dict[str, Any]] = []
|
||||||
|
|
||||||
for sub_name, sub_status in status.items():
|
for sub_name, sub_status in status.items():
|
||||||
runtime = sub_status.get("runtime", {})
|
runtime = sub_status.get("runtime", {})
|
||||||
conn_state = runtime.get("connection_state", "unknown")
|
conn_state = runtime.get("connection_state", "unknown")
|
||||||
if conn_state in ("error", "auth_failed", "timeout", "max_retries_exceeded", "invalid_uri"):
|
if conn_state in _error_states:
|
||||||
error_count += 1
|
error_count += 1
|
||||||
if runtime.get("last_error"):
|
# Gate on current failure state so recovered subscriptions are not reported
|
||||||
|
if runtime.get("last_error") and conn_state in _error_states:
|
||||||
connection_issues.append(
|
connection_issues.append(
|
||||||
{
|
{
|
||||||
"subscription": sub_name,
|
"subscription": sub_name,
|
||||||
|
|||||||
Reference in New Issue
Block a user