refactor: simplify path validation and connection_init via shared helpers

- Extract _validate_path() in unraid.py — consolidates traversal check + normpath
  + prefix validation used by disk/logs and live/log_tail into one place
- Extract build_connection_init() in subscriptions/utils.py — removes 4 duplicate
  connection_init payload blocks from snapshot.py (×2), manager.py, diagnostics.py;
  also fixes diagnostics.py bug where x-api-key: None was sent when no key configured
- Remove _LIVE_ALLOWED_LOG_PREFIXES alias — direct reference to _ALLOWED_LOG_PREFIXES
- Move import hmac to module level in server.py (was inside verify_token hot path)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Jacob Magar
2026-03-23 11:57:00 -04:00
parent dc1e5f18d8
commit e68d4a80e4
7 changed files with 51 additions and 49 deletions

View File

@@ -4,6 +4,7 @@ This is the main server implementation using the modular architecture with
separate modules for configuration, core functionality, subscriptions, and tools.
"""
import hmac
import sys
from typing import Any
@@ -94,8 +95,6 @@ class ApiKeyVerifier(TokenVerifier):
self._api_key = api_key
async def verify_token(self, token: str) -> AccessToken | None:
import hmac
if self._api_key and hmac.compare_digest(token.encode(), self._api_key.encode()):
return AccessToken(
token=token,

View File

@@ -21,7 +21,12 @@ from ..core.exceptions import ToolError
from ..core.utils import safe_display_url
from .manager import subscription_manager
from .resources import ensure_subscriptions_started
from .utils import _analyze_subscription_status, build_ws_ssl_context, build_ws_url
from .utils import (
_analyze_subscription_status,
build_connection_init,
build_ws_ssl_context,
build_ws_url,
)
# Schema field names that appear inside the selection set of allowed subscriptions.
@@ -125,15 +130,8 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
ping_interval=30,
ping_timeout=10,
) as websocket:
# Send connection init (using standard X-API-Key format)
await websocket.send(
json.dumps(
{
"type": "connection_init",
"payload": {"x-api-key": _settings.UNRAID_API_KEY},
}
)
)
# Send connection init
await websocket.send(json.dumps(build_connection_init()))
# Wait for ack
response = await websocket.recv()

View File

@@ -19,7 +19,7 @@ from ..config import settings as _settings
from ..config.logging import logger
from ..core.client import redact_sensitive
from ..core.types import SubscriptionData
from .utils import build_ws_ssl_context, build_ws_url
from .utils import build_connection_init, build_ws_ssl_context, build_ws_url
# Resource data size limits to prevent unbounded memory growth
@@ -284,13 +284,9 @@ class SubscriptionManager:
logger.debug(
f"[PROTOCOL:{subscription_name}] Initializing GraphQL-WS protocol..."
)
init_type = "connection_init"
init_payload: dict[str, Any] = {"type": init_type}
if _settings.UNRAID_API_KEY:
init_payload = build_connection_init()
if "payload" in init_payload:
logger.debug(f"[AUTH:{subscription_name}] Adding authentication payload")
# Use graphql-ws connectionParams format (direct key, not nested headers)
init_payload["payload"] = {"x-api-key": _settings.UNRAID_API_KEY}
else:
logger.warning(
f"[AUTH:{subscription_name}] No API key available for authentication"

View File

@@ -18,10 +18,9 @@ from typing import Any
import websockets
from websockets.typing import Subprotocol
from ..config import settings as _settings
from ..config.logging import logger
from ..core.exceptions import ToolError
from .utils import build_ws_ssl_context, build_ws_url
from .utils import build_connection_init, build_ws_ssl_context, build_ws_url
async def subscribe_once(
@@ -48,10 +47,7 @@ async def subscribe_once(
sub_id = "snapshot-1"
# Handshake
init: dict[str, Any] = {"type": "connection_init"}
if _settings.UNRAID_API_KEY:
init["payload"] = {"x-api-key": _settings.UNRAID_API_KEY}
await ws.send(json.dumps(init))
await ws.send(json.dumps(build_connection_init()))
raw = await asyncio.wait_for(ws.recv(), timeout=timeout)
ack = json.loads(raw)
@@ -123,10 +119,7 @@ async def subscribe_collect(
proto = ws.subprotocol or "graphql-transport-ws"
sub_id = "snapshot-1"
init: dict[str, Any] = {"type": "connection_init"}
if _settings.UNRAID_API_KEY:
init["payload"] = {"x-api-key": _settings.UNRAID_API_KEY}
await ws.send(json.dumps(init))
await ws.send(json.dumps(build_connection_init()))
raw = await asyncio.wait_for(ws.recv(), timeout=timeout)
ack = json.loads(raw)

View File

@@ -45,7 +45,7 @@ def build_ws_ssl_context(ws_url: str) -> _ssl.SSLContext | None:
ws_url: The WebSocket URL to connect to.
Returns:
An SSLContext configured per _settings.UNRAID_VERIFY_SSL, or None for non-TLS URLs.
An SSLContext configured per UNRAID_VERIFY_SSL, or None for non-TLS URLs.
"""
if not ws_url.startswith("wss://"):
return None
@@ -60,6 +60,18 @@ def build_ws_ssl_context(ws_url: str) -> _ssl.SSLContext | None:
return ctx
def build_connection_init() -> dict[str, Any]:
"""Build the graphql-ws connection_init message.
Omits the payload key entirely when no API key is configured —
sending {"x-api-key": None} and omitting the key differ for some servers.
"""
msg: dict[str, Any] = {"type": "connection_init"}
if _settings.UNRAID_API_KEY:
msg["payload"] = {"x-api-key": _settings.UNRAID_API_KEY}
return msg
def _analyze_subscription_status(
status: dict[str, Any],
) -> tuple[int, list[dict[str, Any]]]:

View File

@@ -630,6 +630,23 @@ _ALLOWED_LOG_PREFIXES = ("/var/log/", "/boot/logs/")
_MAX_TAIL_LINES = 10_000
def _validate_path(path: str, allowed_prefixes: tuple[str, ...], label: str) -> str:
"""Validate a remote path string for traversal and allowed prefix.
Uses pure string normalization — no filesystem access. The path is validated
locally but consumed on the remote Unraid server, so realpath would resolve
against the wrong filesystem.
Returns the normalized path. Raises ToolError on any violation.
"""
if ".." in path:
raise ToolError(f"{label} must not contain path traversal sequences (../)")
normalized = os.path.normpath(path)
if not any(normalized.startswith(p) for p in allowed_prefixes):
raise ToolError(f"{label} must start with one of: {', '.join(allowed_prefixes)}")
return normalized
async def _handle_disk(
subaction: str,
disk_id: str | None,
@@ -663,13 +680,7 @@ async def _handle_disk(
raise ToolError(f"tail_lines must be between 1 and {_MAX_TAIL_LINES}, got {tail_lines}")
if not log_path:
raise ToolError("log_path is required for disk/logs")
# Validate without filesystem access — path is consumed on the remote Unraid server
if ".." in log_path:
raise ToolError("log_path must not contain path traversal sequences (../)")
normalized = os.path.normpath(log_path) # noqa: ASYNC240 — pure string normalization, no I/O
if not any(normalized.startswith(p) for p in _ALLOWED_LOG_PREFIXES):
raise ToolError(f"log_path must start with one of: {', '.join(_ALLOWED_LOG_PREFIXES)}")
log_path = normalized
log_path = _validate_path(log_path, _ALLOWED_LOG_PREFIXES, "log_path")
if subaction == "flash_backup":
if not remote_name:
@@ -681,9 +692,10 @@ async def _handle_disk(
# Validate paths — flash backup source must come from /boot/ only
if ".." in source_path:
raise ToolError("source_path must not contain path traversal sequences (../)")
_norm_src = os.path.normpath(source_path) # noqa: ASYNC240 — pure string normalization, no I/O
if not (_norm_src == "/boot" or _norm_src.startswith("/boot/")):
normalized = os.path.normpath(source_path) # noqa: ASYNC240 — pure string, no I/O
if not (normalized == "/boot" or normalized.startswith("/boot/")):
raise ToolError("source_path must start with /boot/ (flash drive only)")
source_path = normalized
if ".." in destination_path:
raise ToolError("destination_path must not contain path traversal sequences (../)")
input_data: dict[str, Any] = {
@@ -1641,8 +1653,6 @@ async def _handle_user(subaction: str) -> dict[str, Any]:
# LIVE (subscriptions)
# ===========================================================================
_LIVE_ALLOWED_LOG_PREFIXES = _ALLOWED_LOG_PREFIXES
async def _handle_live(
subaction: str,
@@ -1662,13 +1672,7 @@ async def _handle_live(
if subaction == "log_tail":
if not path:
raise ToolError("path is required for live/log_tail")
# Validate without filesystem access — path is consumed on the remote Unraid server
if ".." in path:
raise ToolError("path must not contain path traversal sequences (../)")
normalized = os.path.normpath(path) # noqa: ASYNC240 — pure string normalization, no I/O
if not any(normalized.startswith(p) for p in _LIVE_ALLOWED_LOG_PREFIXES):
raise ToolError(f"path must start with one of: {', '.join(_LIVE_ALLOWED_LOG_PREFIXES)}")
path = normalized
path = _validate_path(path, _ALLOWED_LOG_PREFIXES, "path")
with tool_error_handler("live", subaction, logger):
logger.info(f"Executing unraid action=live subaction={subaction} timeout={timeout}")

2
uv.lock generated
View File

@@ -1572,7 +1572,7 @@ wheels = [
[[package]]
name = "unraid-mcp"
version = "1.1.1"
version = "1.1.2"
source = { editable = "." }
dependencies = [
{ name = "fastapi" },