diff --git a/unraid_mcp/server.py b/unraid_mcp/server.py index 82dde31..5b41b49 100644 --- a/unraid_mcp/server.py +++ b/unraid_mcp/server.py @@ -4,6 +4,7 @@ This is the main server implementation using the modular architecture with separate modules for configuration, core functionality, subscriptions, and tools. """ +import hmac import sys from typing import Any @@ -94,8 +95,6 @@ class ApiKeyVerifier(TokenVerifier): self._api_key = api_key async def verify_token(self, token: str) -> AccessToken | None: - import hmac - if self._api_key and hmac.compare_digest(token.encode(), self._api_key.encode()): return AccessToken( token=token, diff --git a/unraid_mcp/subscriptions/diagnostics.py b/unraid_mcp/subscriptions/diagnostics.py index 3fd6d3d..82bb0d3 100644 --- a/unraid_mcp/subscriptions/diagnostics.py +++ b/unraid_mcp/subscriptions/diagnostics.py @@ -21,7 +21,12 @@ from ..core.exceptions import ToolError from ..core.utils import safe_display_url from .manager import subscription_manager from .resources import ensure_subscriptions_started -from .utils import _analyze_subscription_status, build_ws_ssl_context, build_ws_url +from .utils import ( + _analyze_subscription_status, + build_connection_init, + build_ws_ssl_context, + build_ws_url, +) # Schema field names that appear inside the selection set of allowed subscriptions. @@ -125,15 +130,8 @@ def register_diagnostic_tools(mcp: FastMCP) -> None: ping_interval=30, ping_timeout=10, ) as websocket: - # Send connection init (using standard X-API-Key format) - await websocket.send( - json.dumps( - { - "type": "connection_init", - "payload": {"x-api-key": _settings.UNRAID_API_KEY}, - } - ) - ) + # Send connection init + await websocket.send(json.dumps(build_connection_init())) # Wait for ack response = await websocket.recv() diff --git a/unraid_mcp/subscriptions/manager.py b/unraid_mcp/subscriptions/manager.py index 93b2cee..c5d9bef 100644 --- a/unraid_mcp/subscriptions/manager.py +++ b/unraid_mcp/subscriptions/manager.py @@ -19,7 +19,7 @@ from ..config import settings as _settings from ..config.logging import logger from ..core.client import redact_sensitive from ..core.types import SubscriptionData -from .utils import build_ws_ssl_context, build_ws_url +from .utils import build_connection_init, build_ws_ssl_context, build_ws_url # Resource data size limits to prevent unbounded memory growth @@ -284,13 +284,9 @@ class SubscriptionManager: logger.debug( f"[PROTOCOL:{subscription_name}] Initializing GraphQL-WS protocol..." ) - init_type = "connection_init" - init_payload: dict[str, Any] = {"type": init_type} - - if _settings.UNRAID_API_KEY: + init_payload = build_connection_init() + if "payload" in init_payload: logger.debug(f"[AUTH:{subscription_name}] Adding authentication payload") - # Use graphql-ws connectionParams format (direct key, not nested headers) - init_payload["payload"] = {"x-api-key": _settings.UNRAID_API_KEY} else: logger.warning( f"[AUTH:{subscription_name}] No API key available for authentication" diff --git a/unraid_mcp/subscriptions/snapshot.py b/unraid_mcp/subscriptions/snapshot.py index d0f6baa..068266a 100644 --- a/unraid_mcp/subscriptions/snapshot.py +++ b/unraid_mcp/subscriptions/snapshot.py @@ -18,10 +18,9 @@ from typing import Any import websockets from websockets.typing import Subprotocol -from ..config import settings as _settings from ..config.logging import logger from ..core.exceptions import ToolError -from .utils import build_ws_ssl_context, build_ws_url +from .utils import build_connection_init, build_ws_ssl_context, build_ws_url async def subscribe_once( @@ -48,10 +47,7 @@ async def subscribe_once( sub_id = "snapshot-1" # Handshake - init: dict[str, Any] = {"type": "connection_init"} - if _settings.UNRAID_API_KEY: - init["payload"] = {"x-api-key": _settings.UNRAID_API_KEY} - await ws.send(json.dumps(init)) + await ws.send(json.dumps(build_connection_init())) raw = await asyncio.wait_for(ws.recv(), timeout=timeout) ack = json.loads(raw) @@ -123,10 +119,7 @@ async def subscribe_collect( proto = ws.subprotocol or "graphql-transport-ws" sub_id = "snapshot-1" - init: dict[str, Any] = {"type": "connection_init"} - if _settings.UNRAID_API_KEY: - init["payload"] = {"x-api-key": _settings.UNRAID_API_KEY} - await ws.send(json.dumps(init)) + await ws.send(json.dumps(build_connection_init())) raw = await asyncio.wait_for(ws.recv(), timeout=timeout) ack = json.loads(raw) diff --git a/unraid_mcp/subscriptions/utils.py b/unraid_mcp/subscriptions/utils.py index 6068611..71115be 100644 --- a/unraid_mcp/subscriptions/utils.py +++ b/unraid_mcp/subscriptions/utils.py @@ -45,7 +45,7 @@ def build_ws_ssl_context(ws_url: str) -> _ssl.SSLContext | None: ws_url: The WebSocket URL to connect to. Returns: - An SSLContext configured per _settings.UNRAID_VERIFY_SSL, or None for non-TLS URLs. + An SSLContext configured per UNRAID_VERIFY_SSL, or None for non-TLS URLs. """ if not ws_url.startswith("wss://"): return None @@ -60,6 +60,18 @@ def build_ws_ssl_context(ws_url: str) -> _ssl.SSLContext | None: return ctx +def build_connection_init() -> dict[str, Any]: + """Build the graphql-ws connection_init message. + + Omits the payload key entirely when no API key is configured — + sending {"x-api-key": None} and omitting the key differ for some servers. + """ + msg: dict[str, Any] = {"type": "connection_init"} + if _settings.UNRAID_API_KEY: + msg["payload"] = {"x-api-key": _settings.UNRAID_API_KEY} + return msg + + def _analyze_subscription_status( status: dict[str, Any], ) -> tuple[int, list[dict[str, Any]]]: diff --git a/unraid_mcp/tools/unraid.py b/unraid_mcp/tools/unraid.py index d96f821..04baf8a 100644 --- a/unraid_mcp/tools/unraid.py +++ b/unraid_mcp/tools/unraid.py @@ -630,6 +630,23 @@ _ALLOWED_LOG_PREFIXES = ("/var/log/", "/boot/logs/") _MAX_TAIL_LINES = 10_000 +def _validate_path(path: str, allowed_prefixes: tuple[str, ...], label: str) -> str: + """Validate a remote path string for traversal and allowed prefix. + + Uses pure string normalization — no filesystem access. The path is validated + locally but consumed on the remote Unraid server, so realpath would resolve + against the wrong filesystem. + + Returns the normalized path. Raises ToolError on any violation. + """ + if ".." in path: + raise ToolError(f"{label} must not contain path traversal sequences (../)") + normalized = os.path.normpath(path) + if not any(normalized.startswith(p) for p in allowed_prefixes): + raise ToolError(f"{label} must start with one of: {', '.join(allowed_prefixes)}") + return normalized + + async def _handle_disk( subaction: str, disk_id: str | None, @@ -663,13 +680,7 @@ async def _handle_disk( raise ToolError(f"tail_lines must be between 1 and {_MAX_TAIL_LINES}, got {tail_lines}") if not log_path: raise ToolError("log_path is required for disk/logs") - # Validate without filesystem access — path is consumed on the remote Unraid server - if ".." in log_path: - raise ToolError("log_path must not contain path traversal sequences (../)") - normalized = os.path.normpath(log_path) # noqa: ASYNC240 — pure string normalization, no I/O - if not any(normalized.startswith(p) for p in _ALLOWED_LOG_PREFIXES): - raise ToolError(f"log_path must start with one of: {', '.join(_ALLOWED_LOG_PREFIXES)}") - log_path = normalized + log_path = _validate_path(log_path, _ALLOWED_LOG_PREFIXES, "log_path") if subaction == "flash_backup": if not remote_name: @@ -681,9 +692,10 @@ async def _handle_disk( # Validate paths — flash backup source must come from /boot/ only if ".." in source_path: raise ToolError("source_path must not contain path traversal sequences (../)") - _norm_src = os.path.normpath(source_path) # noqa: ASYNC240 — pure string normalization, no I/O - if not (_norm_src == "/boot" or _norm_src.startswith("/boot/")): + normalized = os.path.normpath(source_path) # noqa: ASYNC240 — pure string, no I/O + if not (normalized == "/boot" or normalized.startswith("/boot/")): raise ToolError("source_path must start with /boot/ (flash drive only)") + source_path = normalized if ".." in destination_path: raise ToolError("destination_path must not contain path traversal sequences (../)") input_data: dict[str, Any] = { @@ -1641,8 +1653,6 @@ async def _handle_user(subaction: str) -> dict[str, Any]: # LIVE (subscriptions) # =========================================================================== -_LIVE_ALLOWED_LOG_PREFIXES = _ALLOWED_LOG_PREFIXES - async def _handle_live( subaction: str, @@ -1662,13 +1672,7 @@ async def _handle_live( if subaction == "log_tail": if not path: raise ToolError("path is required for live/log_tail") - # Validate without filesystem access — path is consumed on the remote Unraid server - if ".." in path: - raise ToolError("path must not contain path traversal sequences (../)") - normalized = os.path.normpath(path) # noqa: ASYNC240 — pure string normalization, no I/O - if not any(normalized.startswith(p) for p in _LIVE_ALLOWED_LOG_PREFIXES): - raise ToolError(f"path must start with one of: {', '.join(_LIVE_ALLOWED_LOG_PREFIXES)}") - path = normalized + path = _validate_path(path, _ALLOWED_LOG_PREFIXES, "path") with tool_error_handler("live", subaction, logger): logger.info(f"Executing unraid action=live subaction={subaction} timeout={timeout}") diff --git a/uv.lock b/uv.lock index 817ba91..bf5d0d6 100644 --- a/uv.lock +++ b/uv.lock @@ -1572,7 +1572,7 @@ wheels = [ [[package]] name = "unraid-mcp" -version = "1.1.1" +version = "1.1.2" source = { editable = "." } dependencies = [ { name = "fastapi" },