mirror of
https://github.com/jmagar/unraid-mcp.git
synced 2026-03-23 12:39:24 -07:00
fix: correct subscription validation, byte-based log cap, partial autostart, URL scheme
- diagnostics.py: fix allow-list vs field name mismatch in subscription validator
(_ALLOWED_SUBSCRIPTION_FIELDS now contains schema field names like "logFile",
not operation names like "logFileSubscription", matching what _SUBSCRIPTION_NAME_PATTERN
extracts); add _validate_subscription_query() called before any network I/O;
replace chained .replace() URL building with build_ws_url(); gate connection_issues
on current failure state via _analyze_subscription_status()
- manager.py: add _cap_log_content() with byte-count pre-check
(len(value.encode("utf-8", errors="replace")) > _MAX_RESOURCE_DATA_BYTES) so
multibyte UTF-8 content cannot bypass the 1 MB cap
- resources.py: add double-checked locking (_startup_lock) in ensure_subscriptions_started();
propagate exception from auto_start_all_subscriptions() via raise so
_subscriptions_started=True is never set after a failed init
- utils.py: add build_ws_url() that raises ValueError on unknown/missing URL scheme
instead of silently falling through; add _analyze_subscription_status() helper
that gates connection_issues on current failure state
Resolves review threads PRRT_kwDOO6Hdxs50E50Y PRRT_kwDOO6Hdxs50E50a PRRT_kwDOO6Hdxs50E50c PRRT_kwDOO6Hdxs50E50d PRRT_kwDOO6Hdxs50E2iN PRRT_kwDOO6Hdxs50E2h8
This commit is contained in:
@@ -7,6 +7,7 @@ development and debugging purposes.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
@@ -19,7 +20,63 @@ from ..config.settings import UNRAID_API_KEY, UNRAID_API_URL
|
||||
from ..core.exceptions import ToolError
|
||||
from .manager import subscription_manager
|
||||
from .resources import ensure_subscriptions_started
|
||||
from .utils import build_ws_ssl_context
|
||||
from .utils import _analyze_subscription_status, build_ws_ssl_context, build_ws_url
|
||||
|
||||
|
||||
# Schema field names that appear inside the selection set of allowed subscriptions.
|
||||
# The regex _SUBSCRIPTION_NAME_PATTERN extracts the first identifier after the
|
||||
# opening "{", so we list the actual field names used in queries (e.g. "logFile"),
|
||||
# NOT the operation-level names (e.g. "logFileSubscription").
|
||||
_ALLOWED_SUBSCRIPTION_FIELDS = frozenset(
|
||||
{
|
||||
"logFile",
|
||||
"containerStats",
|
||||
"cpu",
|
||||
"memory",
|
||||
"array",
|
||||
"network",
|
||||
"docker",
|
||||
"vm",
|
||||
}
|
||||
)
|
||||
|
||||
# Pattern: must start with "subscription" keyword, then extract the first selected
|
||||
# field name (the word immediately after "{").
|
||||
_SUBSCRIPTION_NAME_PATTERN = re.compile(r"^\s*subscription\b[^{]*\{\s*(\w+)", re.IGNORECASE)
|
||||
# Reject any query that contains a bare "mutation" or "query" operation keyword.
|
||||
_FORBIDDEN_KEYWORDS = re.compile(r"\b(mutation|query)\b", re.IGNORECASE)
|
||||
|
||||
|
||||
def _validate_subscription_query(query: str) -> str:
|
||||
"""Validate that a subscription query is safe to execute.
|
||||
|
||||
Only allows subscription operations targeting whitelisted schema field names.
|
||||
Rejects any query containing mutation/query keywords.
|
||||
|
||||
Returns:
|
||||
The extracted field name (e.g. "logFile").
|
||||
|
||||
Raises:
|
||||
ToolError: If the query fails validation.
|
||||
"""
|
||||
if _FORBIDDEN_KEYWORDS.search(query):
|
||||
raise ToolError("Query rejected: must be a subscription, not a mutation or query.")
|
||||
|
||||
match = _SUBSCRIPTION_NAME_PATTERN.match(query)
|
||||
if not match:
|
||||
raise ToolError(
|
||||
"Query rejected: must start with 'subscription' and contain a valid "
|
||||
'subscription field. Example: subscription { logFile(path: "/var/log/syslog") { content } }'
|
||||
)
|
||||
|
||||
field_name = match.group(1)
|
||||
if field_name not in _ALLOWED_SUBSCRIPTION_FIELDS:
|
||||
raise ToolError(
|
||||
f"Subscription field '{field_name}' is not allowed. "
|
||||
f"Allowed fields: {sorted(_ALLOWED_SUBSCRIPTION_FIELDS)}"
|
||||
)
|
||||
|
||||
return field_name
|
||||
|
||||
|
||||
def register_diagnostic_tools(mcp: FastMCP) -> None:
|
||||
@@ -34,6 +91,8 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
||||
"""Test a GraphQL subscription query directly to debug schema issues.
|
||||
|
||||
Use this to find working subscription field names and structure.
|
||||
Only whitelisted schema fields are permitted (logFile, containerStats,
|
||||
cpu, memory, array, network, docker, vm).
|
||||
|
||||
Args:
|
||||
subscription_query: The GraphQL subscription query to test
|
||||
@@ -41,16 +100,17 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
||||
Returns:
|
||||
Dict containing test results and response data
|
||||
"""
|
||||
try:
|
||||
logger.info(f"[TEST_SUBSCRIPTION] Testing query: {subscription_query}")
|
||||
# Validate before any network I/O (Bug 1 fix)
|
||||
field_name = _validate_subscription_query(subscription_query)
|
||||
|
||||
# Build WebSocket URL
|
||||
if not UNRAID_API_URL:
|
||||
raise ToolError("UNRAID_API_URL is not configured")
|
||||
ws_url = (
|
||||
UNRAID_API_URL.replace("https://", "wss://").replace("http://", "ws://")
|
||||
+ "/graphql"
|
||||
)
|
||||
try:
|
||||
logger.info(f"[TEST_SUBSCRIPTION] Testing validated subscription field '{field_name}'")
|
||||
|
||||
# Build WebSocket URL — raises ValueError on invalid/missing scheme (Bug 4 fix)
|
||||
try:
|
||||
ws_url = build_ws_url()
|
||||
except ValueError as e:
|
||||
raise ToolError(str(e)) from e
|
||||
|
||||
ssl_context = build_ws_ssl_context(ws_url)
|
||||
|
||||
@@ -59,6 +119,7 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
||||
ws_url,
|
||||
subprotocols=[Subprotocol("graphql-transport-ws"), Subprotocol("graphql-ws")],
|
||||
ssl=ssl_context,
|
||||
open_timeout=10,
|
||||
ping_interval=30,
|
||||
ping_timeout=10,
|
||||
) as websocket:
|
||||
@@ -102,6 +163,8 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
||||
"note": "Connection successful, subscription may be waiting for events",
|
||||
}
|
||||
|
||||
except ToolError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[TEST_SUBSCRIPTION] Error: {e}", exc_info=True)
|
||||
return {"error": str(e), "query_tested": subscription_query}
|
||||
@@ -124,8 +187,17 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
||||
# Get comprehensive status
|
||||
status = subscription_manager.get_subscription_status()
|
||||
|
||||
# Initialize connection issues list with proper type
|
||||
connection_issues: list[dict[str, Any]] = []
|
||||
# Analyze connection issues and error counts via shared helper.
|
||||
# Gates connection_issues on current failure state (Bug 5 fix).
|
||||
error_count, connection_issues = _analyze_subscription_status(status)
|
||||
|
||||
# Calculate WebSocket URL
|
||||
ws_url_display: str | None = None
|
||||
if UNRAID_API_URL:
|
||||
try:
|
||||
ws_url_display = build_ws_url()
|
||||
except ValueError:
|
||||
ws_url_display = None
|
||||
|
||||
# Add environment info with explicit typing
|
||||
diagnostic_info: dict[str, Any] = {
|
||||
@@ -135,7 +207,7 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
||||
"max_reconnect_attempts": subscription_manager.max_reconnect_attempts,
|
||||
"unraid_api_url": UNRAID_API_URL[:50] + "..." if UNRAID_API_URL else None,
|
||||
"api_key_configured": bool(UNRAID_API_KEY),
|
||||
"websocket_url": None,
|
||||
"websocket_url": ws_url_display,
|
||||
},
|
||||
"subscriptions": status,
|
||||
"summary": {
|
||||
@@ -147,40 +219,11 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
||||
),
|
||||
"active_count": len(subscription_manager.active_subscriptions),
|
||||
"with_data": len(subscription_manager.resource_data),
|
||||
"in_error_state": 0,
|
||||
"in_error_state": error_count,
|
||||
"connection_issues": connection_issues,
|
||||
},
|
||||
}
|
||||
|
||||
# Calculate WebSocket URL
|
||||
if UNRAID_API_URL:
|
||||
if UNRAID_API_URL.startswith("https://"):
|
||||
ws_url = "wss://" + UNRAID_API_URL[len("https://") :]
|
||||
elif UNRAID_API_URL.startswith("http://"):
|
||||
ws_url = "ws://" + UNRAID_API_URL[len("http://") :]
|
||||
else:
|
||||
ws_url = UNRAID_API_URL
|
||||
if not ws_url.endswith("/graphql"):
|
||||
ws_url = ws_url.rstrip("/") + "/graphql"
|
||||
diagnostic_info["environment"]["websocket_url"] = ws_url
|
||||
|
||||
# Analyze issues
|
||||
for sub_name, sub_status in status.items():
|
||||
runtime = sub_status.get("runtime", {})
|
||||
connection_state = runtime.get("connection_state", "unknown")
|
||||
|
||||
if connection_state in ["error", "auth_failed", "timeout", "max_retries_exceeded"]:
|
||||
diagnostic_info["summary"]["in_error_state"] += 1
|
||||
|
||||
if runtime.get("last_error"):
|
||||
connection_issues.append(
|
||||
{
|
||||
"subscription": sub_name,
|
||||
"state": connection_state,
|
||||
"error": runtime["last_error"],
|
||||
}
|
||||
)
|
||||
|
||||
# Add troubleshooting recommendations
|
||||
recommendations: list[str] = []
|
||||
|
||||
|
||||
Reference in New Issue
Block a user