refactor: simplify path validation and connection_init via shared helpers

- Extract _validate_path() in unraid.py — consolidates traversal check + normpath
  + prefix validation used by disk/logs and live/log_tail into one place
- Extract build_connection_init() in subscriptions/utils.py — removes 4 duplicate
  connection_init payload blocks from snapshot.py (×2), manager.py, diagnostics.py;
  also fixes diagnostics.py bug where x-api-key: None was sent when no API key was configured
- Remove _LIVE_ALLOWED_LOG_PREFIXES alias — call sites now reference _ALLOWED_LOG_PREFIXES directly
- Move import hmac to module level in server.py (was inside verify_token hot path)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Jacob Magar
2026-03-23 11:57:00 -04:00
parent dc1e5f18d8
commit e68d4a80e4
7 changed files with 51 additions and 49 deletions

View File

@@ -4,6 +4,7 @@ This is the main server implementation using the modular architecture with
separate modules for configuration, core functionality, subscriptions, and tools. separate modules for configuration, core functionality, subscriptions, and tools.
""" """
import hmac
import sys import sys
from typing import Any from typing import Any
@@ -94,8 +95,6 @@ class ApiKeyVerifier(TokenVerifier):
self._api_key = api_key self._api_key = api_key
async def verify_token(self, token: str) -> AccessToken | None: async def verify_token(self, token: str) -> AccessToken | None:
import hmac
if self._api_key and hmac.compare_digest(token.encode(), self._api_key.encode()): if self._api_key and hmac.compare_digest(token.encode(), self._api_key.encode()):
return AccessToken( return AccessToken(
token=token, token=token,

View File

@@ -21,7 +21,12 @@ from ..core.exceptions import ToolError
from ..core.utils import safe_display_url from ..core.utils import safe_display_url
from .manager import subscription_manager from .manager import subscription_manager
from .resources import ensure_subscriptions_started from .resources import ensure_subscriptions_started
from .utils import _analyze_subscription_status, build_ws_ssl_context, build_ws_url from .utils import (
_analyze_subscription_status,
build_connection_init,
build_ws_ssl_context,
build_ws_url,
)
# Schema field names that appear inside the selection set of allowed subscriptions. # Schema field names that appear inside the selection set of allowed subscriptions.
@@ -125,15 +130,8 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
ping_interval=30, ping_interval=30,
ping_timeout=10, ping_timeout=10,
) as websocket: ) as websocket:
# Send connection init (using standard X-API-Key format) # Send connection init
await websocket.send( await websocket.send(json.dumps(build_connection_init()))
json.dumps(
{
"type": "connection_init",
"payload": {"x-api-key": _settings.UNRAID_API_KEY},
}
)
)
# Wait for ack # Wait for ack
response = await websocket.recv() response = await websocket.recv()

View File

@@ -19,7 +19,7 @@ from ..config import settings as _settings
from ..config.logging import logger from ..config.logging import logger
from ..core.client import redact_sensitive from ..core.client import redact_sensitive
from ..core.types import SubscriptionData from ..core.types import SubscriptionData
from .utils import build_ws_ssl_context, build_ws_url from .utils import build_connection_init, build_ws_ssl_context, build_ws_url
# Resource data size limits to prevent unbounded memory growth # Resource data size limits to prevent unbounded memory growth
@@ -284,13 +284,9 @@ class SubscriptionManager:
logger.debug( logger.debug(
f"[PROTOCOL:{subscription_name}] Initializing GraphQL-WS protocol..." f"[PROTOCOL:{subscription_name}] Initializing GraphQL-WS protocol..."
) )
init_type = "connection_init" init_payload = build_connection_init()
init_payload: dict[str, Any] = {"type": init_type} if "payload" in init_payload:
if _settings.UNRAID_API_KEY:
logger.debug(f"[AUTH:{subscription_name}] Adding authentication payload") logger.debug(f"[AUTH:{subscription_name}] Adding authentication payload")
# Use graphql-ws connectionParams format (direct key, not nested headers)
init_payload["payload"] = {"x-api-key": _settings.UNRAID_API_KEY}
else: else:
logger.warning( logger.warning(
f"[AUTH:{subscription_name}] No API key available for authentication" f"[AUTH:{subscription_name}] No API key available for authentication"

View File

@@ -18,10 +18,9 @@ from typing import Any
import websockets import websockets
from websockets.typing import Subprotocol from websockets.typing import Subprotocol
from ..config import settings as _settings
from ..config.logging import logger from ..config.logging import logger
from ..core.exceptions import ToolError from ..core.exceptions import ToolError
from .utils import build_ws_ssl_context, build_ws_url from .utils import build_connection_init, build_ws_ssl_context, build_ws_url
async def subscribe_once( async def subscribe_once(
@@ -48,10 +47,7 @@ async def subscribe_once(
sub_id = "snapshot-1" sub_id = "snapshot-1"
# Handshake # Handshake
init: dict[str, Any] = {"type": "connection_init"} await ws.send(json.dumps(build_connection_init()))
if _settings.UNRAID_API_KEY:
init["payload"] = {"x-api-key": _settings.UNRAID_API_KEY}
await ws.send(json.dumps(init))
raw = await asyncio.wait_for(ws.recv(), timeout=timeout) raw = await asyncio.wait_for(ws.recv(), timeout=timeout)
ack = json.loads(raw) ack = json.loads(raw)
@@ -123,10 +119,7 @@ async def subscribe_collect(
proto = ws.subprotocol or "graphql-transport-ws" proto = ws.subprotocol or "graphql-transport-ws"
sub_id = "snapshot-1" sub_id = "snapshot-1"
init: dict[str, Any] = {"type": "connection_init"} await ws.send(json.dumps(build_connection_init()))
if _settings.UNRAID_API_KEY:
init["payload"] = {"x-api-key": _settings.UNRAID_API_KEY}
await ws.send(json.dumps(init))
raw = await asyncio.wait_for(ws.recv(), timeout=timeout) raw = await asyncio.wait_for(ws.recv(), timeout=timeout)
ack = json.loads(raw) ack = json.loads(raw)

View File

@@ -45,7 +45,7 @@ def build_ws_ssl_context(ws_url: str) -> _ssl.SSLContext | None:
ws_url: The WebSocket URL to connect to. ws_url: The WebSocket URL to connect to.
Returns: Returns:
An SSLContext configured per _settings.UNRAID_VERIFY_SSL, or None for non-TLS URLs. An SSLContext configured per UNRAID_VERIFY_SSL, or None for non-TLS URLs.
""" """
if not ws_url.startswith("wss://"): if not ws_url.startswith("wss://"):
return None return None
@@ -60,6 +60,18 @@ def build_ws_ssl_context(ws_url: str) -> _ssl.SSLContext | None:
return ctx return ctx
def build_connection_init() -> dict[str, Any]:
"""Build the graphql-ws connection_init message.
Omits the payload key entirely when no API key is configured —
sending {"x-api-key": None} and omitting the key differ for some servers.
"""
msg: dict[str, Any] = {"type": "connection_init"}
if _settings.UNRAID_API_KEY:
msg["payload"] = {"x-api-key": _settings.UNRAID_API_KEY}
return msg
def _analyze_subscription_status( def _analyze_subscription_status(
status: dict[str, Any], status: dict[str, Any],
) -> tuple[int, list[dict[str, Any]]]: ) -> tuple[int, list[dict[str, Any]]]:

View File

@@ -630,6 +630,23 @@ _ALLOWED_LOG_PREFIXES = ("/var/log/", "/boot/logs/")
_MAX_TAIL_LINES = 10_000 _MAX_TAIL_LINES = 10_000
def _validate_path(path: str, allowed_prefixes: tuple[str, ...], label: str) -> str:
"""Validate a remote path string for traversal and allowed prefix.
Uses pure string normalization — no filesystem access. The path is validated
locally but consumed on the remote Unraid server, so realpath would resolve
against the wrong filesystem.
Returns the normalized path. Raises ToolError on any violation.
"""
if ".." in path:
raise ToolError(f"{label} must not contain path traversal sequences (../)")
normalized = os.path.normpath(path)
if not any(normalized.startswith(p) for p in allowed_prefixes):
raise ToolError(f"{label} must start with one of: {', '.join(allowed_prefixes)}")
return normalized
async def _handle_disk( async def _handle_disk(
subaction: str, subaction: str,
disk_id: str | None, disk_id: str | None,
@@ -663,13 +680,7 @@ async def _handle_disk(
raise ToolError(f"tail_lines must be between 1 and {_MAX_TAIL_LINES}, got {tail_lines}") raise ToolError(f"tail_lines must be between 1 and {_MAX_TAIL_LINES}, got {tail_lines}")
if not log_path: if not log_path:
raise ToolError("log_path is required for disk/logs") raise ToolError("log_path is required for disk/logs")
# Validate without filesystem access — path is consumed on the remote Unraid server log_path = _validate_path(log_path, _ALLOWED_LOG_PREFIXES, "log_path")
if ".." in log_path:
raise ToolError("log_path must not contain path traversal sequences (../)")
normalized = os.path.normpath(log_path) # noqa: ASYNC240 — pure string normalization, no I/O
if not any(normalized.startswith(p) for p in _ALLOWED_LOG_PREFIXES):
raise ToolError(f"log_path must start with one of: {', '.join(_ALLOWED_LOG_PREFIXES)}")
log_path = normalized
if subaction == "flash_backup": if subaction == "flash_backup":
if not remote_name: if not remote_name:
@@ -681,9 +692,10 @@ async def _handle_disk(
# Validate paths — flash backup source must come from /boot/ only # Validate paths — flash backup source must come from /boot/ only
if ".." in source_path: if ".." in source_path:
raise ToolError("source_path must not contain path traversal sequences (../)") raise ToolError("source_path must not contain path traversal sequences (../)")
_norm_src = os.path.normpath(source_path) # noqa: ASYNC240 — pure string normalization, no I/O normalized = os.path.normpath(source_path) # noqa: ASYNC240 — pure string, no I/O
if not (_norm_src == "/boot" or _norm_src.startswith("/boot/")): if not (normalized == "/boot" or normalized.startswith("/boot/")):
raise ToolError("source_path must start with /boot/ (flash drive only)") raise ToolError("source_path must start with /boot/ (flash drive only)")
source_path = normalized
if ".." in destination_path: if ".." in destination_path:
raise ToolError("destination_path must not contain path traversal sequences (../)") raise ToolError("destination_path must not contain path traversal sequences (../)")
input_data: dict[str, Any] = { input_data: dict[str, Any] = {
@@ -1641,8 +1653,6 @@ async def _handle_user(subaction: str) -> dict[str, Any]:
# LIVE (subscriptions) # LIVE (subscriptions)
# =========================================================================== # ===========================================================================
_LIVE_ALLOWED_LOG_PREFIXES = _ALLOWED_LOG_PREFIXES
async def _handle_live( async def _handle_live(
subaction: str, subaction: str,
@@ -1662,13 +1672,7 @@ async def _handle_live(
if subaction == "log_tail": if subaction == "log_tail":
if not path: if not path:
raise ToolError("path is required for live/log_tail") raise ToolError("path is required for live/log_tail")
# Validate without filesystem access — path is consumed on the remote Unraid server path = _validate_path(path, _ALLOWED_LOG_PREFIXES, "path")
if ".." in path:
raise ToolError("path must not contain path traversal sequences (../)")
normalized = os.path.normpath(path) # noqa: ASYNC240 — pure string normalization, no I/O
if not any(normalized.startswith(p) for p in _LIVE_ALLOWED_LOG_PREFIXES):
raise ToolError(f"path must start with one of: {', '.join(_LIVE_ALLOWED_LOG_PREFIXES)}")
path = normalized
with tool_error_handler("live", subaction, logger): with tool_error_handler("live", subaction, logger):
logger.info(f"Executing unraid action=live subaction={subaction} timeout={timeout}") logger.info(f"Executing unraid action=live subaction={subaction} timeout={timeout}")

2
uv.lock generated
View File

@@ -1572,7 +1572,7 @@ wheels = [
[[package]] [[package]]
name = "unraid-mcp" name = "unraid-mcp"
version = "1.1.1" version = "1.1.2"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "fastapi" }, { name = "fastapi" },