mirror of
https://github.com/jmagar/unraid-mcp.git
synced 2026-03-23 12:39:24 -07:00
refactor: simplify path validation and connection_init via shared helpers
- Extract _validate_path() in unraid.py — consolidates traversal check + normpath + prefix validation used by disk/logs and live/log_tail into one place - Extract build_connection_init() in subscriptions/utils.py — removes 4 duplicate connection_init payload blocks from snapshot.py (×2), manager.py, diagnostics.py; also fixes diagnostics.py bug where x-api-key: None was sent when no key configured - Remove _LIVE_ALLOWED_LOG_PREFIXES alias — direct reference to _ALLOWED_LOG_PREFIXES - Move import hmac to module level in server.py (was inside verify_token hot path) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -4,6 +4,7 @@ This is the main server implementation using the modular architecture with
|
||||
separate modules for configuration, core functionality, subscriptions, and tools.
|
||||
"""
|
||||
|
||||
import hmac
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
@@ -94,8 +95,6 @@ class ApiKeyVerifier(TokenVerifier):
|
||||
self._api_key = api_key
|
||||
|
||||
async def verify_token(self, token: str) -> AccessToken | None:
|
||||
import hmac
|
||||
|
||||
if self._api_key and hmac.compare_digest(token.encode(), self._api_key.encode()):
|
||||
return AccessToken(
|
||||
token=token,
|
||||
|
||||
@@ -21,7 +21,12 @@ from ..core.exceptions import ToolError
|
||||
from ..core.utils import safe_display_url
|
||||
from .manager import subscription_manager
|
||||
from .resources import ensure_subscriptions_started
|
||||
from .utils import _analyze_subscription_status, build_ws_ssl_context, build_ws_url
|
||||
from .utils import (
|
||||
_analyze_subscription_status,
|
||||
build_connection_init,
|
||||
build_ws_ssl_context,
|
||||
build_ws_url,
|
||||
)
|
||||
|
||||
|
||||
# Schema field names that appear inside the selection set of allowed subscriptions.
|
||||
@@ -125,15 +130,8 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
||||
ping_interval=30,
|
||||
ping_timeout=10,
|
||||
) as websocket:
|
||||
# Send connection init (using standard X-API-Key format)
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "connection_init",
|
||||
"payload": {"x-api-key": _settings.UNRAID_API_KEY},
|
||||
}
|
||||
)
|
||||
)
|
||||
# Send connection init
|
||||
await websocket.send(json.dumps(build_connection_init()))
|
||||
|
||||
# Wait for ack
|
||||
response = await websocket.recv()
|
||||
|
||||
@@ -19,7 +19,7 @@ from ..config import settings as _settings
|
||||
from ..config.logging import logger
|
||||
from ..core.client import redact_sensitive
|
||||
from ..core.types import SubscriptionData
|
||||
from .utils import build_ws_ssl_context, build_ws_url
|
||||
from .utils import build_connection_init, build_ws_ssl_context, build_ws_url
|
||||
|
||||
|
||||
# Resource data size limits to prevent unbounded memory growth
|
||||
@@ -284,13 +284,9 @@ class SubscriptionManager:
|
||||
logger.debug(
|
||||
f"[PROTOCOL:{subscription_name}] Initializing GraphQL-WS protocol..."
|
||||
)
|
||||
init_type = "connection_init"
|
||||
init_payload: dict[str, Any] = {"type": init_type}
|
||||
|
||||
if _settings.UNRAID_API_KEY:
|
||||
init_payload = build_connection_init()
|
||||
if "payload" in init_payload:
|
||||
logger.debug(f"[AUTH:{subscription_name}] Adding authentication payload")
|
||||
# Use graphql-ws connectionParams format (direct key, not nested headers)
|
||||
init_payload["payload"] = {"x-api-key": _settings.UNRAID_API_KEY}
|
||||
else:
|
||||
logger.warning(
|
||||
f"[AUTH:{subscription_name}] No API key available for authentication"
|
||||
|
||||
@@ -18,10 +18,9 @@ from typing import Any
|
||||
import websockets
|
||||
from websockets.typing import Subprotocol
|
||||
|
||||
from ..config import settings as _settings
|
||||
from ..config.logging import logger
|
||||
from ..core.exceptions import ToolError
|
||||
from .utils import build_ws_ssl_context, build_ws_url
|
||||
from .utils import build_connection_init, build_ws_ssl_context, build_ws_url
|
||||
|
||||
|
||||
async def subscribe_once(
|
||||
@@ -48,10 +47,7 @@ async def subscribe_once(
|
||||
sub_id = "snapshot-1"
|
||||
|
||||
# Handshake
|
||||
init: dict[str, Any] = {"type": "connection_init"}
|
||||
if _settings.UNRAID_API_KEY:
|
||||
init["payload"] = {"x-api-key": _settings.UNRAID_API_KEY}
|
||||
await ws.send(json.dumps(init))
|
||||
await ws.send(json.dumps(build_connection_init()))
|
||||
|
||||
raw = await asyncio.wait_for(ws.recv(), timeout=timeout)
|
||||
ack = json.loads(raw)
|
||||
@@ -123,10 +119,7 @@ async def subscribe_collect(
|
||||
proto = ws.subprotocol or "graphql-transport-ws"
|
||||
sub_id = "snapshot-1"
|
||||
|
||||
init: dict[str, Any] = {"type": "connection_init"}
|
||||
if _settings.UNRAID_API_KEY:
|
||||
init["payload"] = {"x-api-key": _settings.UNRAID_API_KEY}
|
||||
await ws.send(json.dumps(init))
|
||||
await ws.send(json.dumps(build_connection_init()))
|
||||
|
||||
raw = await asyncio.wait_for(ws.recv(), timeout=timeout)
|
||||
ack = json.loads(raw)
|
||||
|
||||
@@ -45,7 +45,7 @@ def build_ws_ssl_context(ws_url: str) -> _ssl.SSLContext | None:
|
||||
ws_url: The WebSocket URL to connect to.
|
||||
|
||||
Returns:
|
||||
An SSLContext configured per _settings.UNRAID_VERIFY_SSL, or None for non-TLS URLs.
|
||||
An SSLContext configured per UNRAID_VERIFY_SSL, or None for non-TLS URLs.
|
||||
"""
|
||||
if not ws_url.startswith("wss://"):
|
||||
return None
|
||||
@@ -60,6 +60,18 @@ def build_ws_ssl_context(ws_url: str) -> _ssl.SSLContext | None:
|
||||
return ctx
|
||||
|
||||
|
||||
def build_connection_init() -> dict[str, Any]:
|
||||
"""Build the graphql-ws connection_init message.
|
||||
|
||||
Omits the payload key entirely when no API key is configured —
|
||||
sending {"x-api-key": None} and omitting the key differ for some servers.
|
||||
"""
|
||||
msg: dict[str, Any] = {"type": "connection_init"}
|
||||
if _settings.UNRAID_API_KEY:
|
||||
msg["payload"] = {"x-api-key": _settings.UNRAID_API_KEY}
|
||||
return msg
|
||||
|
||||
|
||||
def _analyze_subscription_status(
|
||||
status: dict[str, Any],
|
||||
) -> tuple[int, list[dict[str, Any]]]:
|
||||
|
||||
@@ -630,6 +630,23 @@ _ALLOWED_LOG_PREFIXES = ("/var/log/", "/boot/logs/")
|
||||
_MAX_TAIL_LINES = 10_000
|
||||
|
||||
|
||||
def _validate_path(path: str, allowed_prefixes: tuple[str, ...], label: str) -> str:
|
||||
"""Validate a remote path string for traversal and allowed prefix.
|
||||
|
||||
Uses pure string normalization — no filesystem access. The path is validated
|
||||
locally but consumed on the remote Unraid server, so realpath would resolve
|
||||
against the wrong filesystem.
|
||||
|
||||
Returns the normalized path. Raises ToolError on any violation.
|
||||
"""
|
||||
if ".." in path:
|
||||
raise ToolError(f"{label} must not contain path traversal sequences (../)")
|
||||
normalized = os.path.normpath(path)
|
||||
if not any(normalized.startswith(p) for p in allowed_prefixes):
|
||||
raise ToolError(f"{label} must start with one of: {', '.join(allowed_prefixes)}")
|
||||
return normalized
|
||||
|
||||
|
||||
async def _handle_disk(
|
||||
subaction: str,
|
||||
disk_id: str | None,
|
||||
@@ -663,13 +680,7 @@ async def _handle_disk(
|
||||
raise ToolError(f"tail_lines must be between 1 and {_MAX_TAIL_LINES}, got {tail_lines}")
|
||||
if not log_path:
|
||||
raise ToolError("log_path is required for disk/logs")
|
||||
# Validate without filesystem access — path is consumed on the remote Unraid server
|
||||
if ".." in log_path:
|
||||
raise ToolError("log_path must not contain path traversal sequences (../)")
|
||||
normalized = os.path.normpath(log_path) # noqa: ASYNC240 — pure string normalization, no I/O
|
||||
if not any(normalized.startswith(p) for p in _ALLOWED_LOG_PREFIXES):
|
||||
raise ToolError(f"log_path must start with one of: {', '.join(_ALLOWED_LOG_PREFIXES)}")
|
||||
log_path = normalized
|
||||
log_path = _validate_path(log_path, _ALLOWED_LOG_PREFIXES, "log_path")
|
||||
|
||||
if subaction == "flash_backup":
|
||||
if not remote_name:
|
||||
@@ -681,9 +692,10 @@ async def _handle_disk(
|
||||
# Validate paths — flash backup source must come from /boot/ only
|
||||
if ".." in source_path:
|
||||
raise ToolError("source_path must not contain path traversal sequences (../)")
|
||||
_norm_src = os.path.normpath(source_path) # noqa: ASYNC240 — pure string normalization, no I/O
|
||||
if not (_norm_src == "/boot" or _norm_src.startswith("/boot/")):
|
||||
normalized = os.path.normpath(source_path) # noqa: ASYNC240 — pure string, no I/O
|
||||
if not (normalized == "/boot" or normalized.startswith("/boot/")):
|
||||
raise ToolError("source_path must start with /boot/ (flash drive only)")
|
||||
source_path = normalized
|
||||
if ".." in destination_path:
|
||||
raise ToolError("destination_path must not contain path traversal sequences (../)")
|
||||
input_data: dict[str, Any] = {
|
||||
@@ -1641,8 +1653,6 @@ async def _handle_user(subaction: str) -> dict[str, Any]:
|
||||
# LIVE (subscriptions)
|
||||
# ===========================================================================
|
||||
|
||||
_LIVE_ALLOWED_LOG_PREFIXES = _ALLOWED_LOG_PREFIXES
|
||||
|
||||
|
||||
async def _handle_live(
|
||||
subaction: str,
|
||||
@@ -1662,13 +1672,7 @@ async def _handle_live(
|
||||
if subaction == "log_tail":
|
||||
if not path:
|
||||
raise ToolError("path is required for live/log_tail")
|
||||
# Validate without filesystem access — path is consumed on the remote Unraid server
|
||||
if ".." in path:
|
||||
raise ToolError("path must not contain path traversal sequences (../)")
|
||||
normalized = os.path.normpath(path) # noqa: ASYNC240 — pure string normalization, no I/O
|
||||
if not any(normalized.startswith(p) for p in _LIVE_ALLOWED_LOG_PREFIXES):
|
||||
raise ToolError(f"path must start with one of: {', '.join(_LIVE_ALLOWED_LOG_PREFIXES)}")
|
||||
path = normalized
|
||||
path = _validate_path(path, _ALLOWED_LOG_PREFIXES, "path")
|
||||
|
||||
with tool_error_handler("live", subaction, logger):
|
||||
logger.info(f"Executing unraid action=live subaction={subaction} timeout={timeout}")
|
||||
|
||||
Reference in New Issue
Block a user