mirror of
https://github.com/jmagar/unraid-mcp.git
synced 2026-03-23 20:42:58 -07:00
refactor: simplify path validation and connection_init via shared helpers
- Extract _validate_path() in unraid.py — consolidates traversal check + normpath + prefix validation used by disk/logs and live/log_tail into one place - Extract build_connection_init() in subscriptions/utils.py — removes 4 duplicate connection_init payload blocks from snapshot.py (×2), manager.py, diagnostics.py; also fixes diagnostics.py bug where x-api-key: None was sent when no key configured - Remove _LIVE_ALLOWED_LOG_PREFIXES alias — direct reference to _ALLOWED_LOG_PREFIXES - Move import hmac to module level in server.py (was inside verify_token hot path) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -630,6 +630,23 @@ _ALLOWED_LOG_PREFIXES = ("/var/log/", "/boot/logs/")
|
||||
_MAX_TAIL_LINES = 10_000
|
||||
|
||||
|
||||
def _validate_path(path: str, allowed_prefixes: tuple[str, ...], label: str) -> str:
|
||||
"""Validate a remote path string for traversal and allowed prefix.
|
||||
|
||||
Uses pure string normalization — no filesystem access. The path is validated
|
||||
locally but consumed on the remote Unraid server, so realpath would resolve
|
||||
against the wrong filesystem.
|
||||
|
||||
Returns the normalized path. Raises ToolError on any violation.
|
||||
"""
|
||||
if ".." in path:
|
||||
raise ToolError(f"{label} must not contain path traversal sequences (../)")
|
||||
normalized = os.path.normpath(path)
|
||||
if not any(normalized.startswith(p) for p in allowed_prefixes):
|
||||
raise ToolError(f"{label} must start with one of: {', '.join(allowed_prefixes)}")
|
||||
return normalized
|
||||
|
||||
|
||||
async def _handle_disk(
|
||||
subaction: str,
|
||||
disk_id: str | None,
|
||||
@@ -663,13 +680,7 @@ async def _handle_disk(
|
||||
raise ToolError(f"tail_lines must be between 1 and {_MAX_TAIL_LINES}, got {tail_lines}")
|
||||
if not log_path:
|
||||
raise ToolError("log_path is required for disk/logs")
|
||||
# Validate without filesystem access — path is consumed on the remote Unraid server
|
||||
if ".." in log_path:
|
||||
raise ToolError("log_path must not contain path traversal sequences (../)")
|
||||
normalized = os.path.normpath(log_path) # noqa: ASYNC240 — pure string normalization, no I/O
|
||||
if not any(normalized.startswith(p) for p in _ALLOWED_LOG_PREFIXES):
|
||||
raise ToolError(f"log_path must start with one of: {', '.join(_ALLOWED_LOG_PREFIXES)}")
|
||||
log_path = normalized
|
||||
log_path = _validate_path(log_path, _ALLOWED_LOG_PREFIXES, "log_path")
|
||||
|
||||
if subaction == "flash_backup":
|
||||
if not remote_name:
|
||||
@@ -681,9 +692,10 @@ async def _handle_disk(
|
||||
# Validate paths — flash backup source must come from /boot/ only
|
||||
if ".." in source_path:
|
||||
raise ToolError("source_path must not contain path traversal sequences (../)")
|
||||
_norm_src = os.path.normpath(source_path) # noqa: ASYNC240 — pure string normalization, no I/O
|
||||
if not (_norm_src == "/boot" or _norm_src.startswith("/boot/")):
|
||||
normalized = os.path.normpath(source_path) # noqa: ASYNC240 — pure string, no I/O
|
||||
if not (normalized == "/boot" or normalized.startswith("/boot/")):
|
||||
raise ToolError("source_path must start with /boot/ (flash drive only)")
|
||||
source_path = normalized
|
||||
if ".." in destination_path:
|
||||
raise ToolError("destination_path must not contain path traversal sequences (../)")
|
||||
input_data: dict[str, Any] = {
|
||||
@@ -1641,8 +1653,6 @@ async def _handle_user(subaction: str) -> dict[str, Any]:
|
||||
# LIVE (subscriptions)
|
||||
# ===========================================================================
|
||||
|
||||
_LIVE_ALLOWED_LOG_PREFIXES = _ALLOWED_LOG_PREFIXES
|
||||
|
||||
|
||||
async def _handle_live(
|
||||
subaction: str,
|
||||
@@ -1662,13 +1672,7 @@ async def _handle_live(
|
||||
if subaction == "log_tail":
|
||||
if not path:
|
||||
raise ToolError("path is required for live/log_tail")
|
||||
# Validate without filesystem access — path is consumed on the remote Unraid server
|
||||
if ".." in path:
|
||||
raise ToolError("path must not contain path traversal sequences (../)")
|
||||
normalized = os.path.normpath(path) # noqa: ASYNC240 — pure string normalization, no I/O
|
||||
if not any(normalized.startswith(p) for p in _LIVE_ALLOWED_LOG_PREFIXES):
|
||||
raise ToolError(f"path must start with one of: {', '.join(_LIVE_ALLOWED_LOG_PREFIXES)}")
|
||||
path = normalized
|
||||
path = _validate_path(path, _ALLOWED_LOG_PREFIXES, "path")
|
||||
|
||||
with tool_error_handler("live", subaction, logger):
|
||||
logger.info(f"Executing unraid action=live subaction={subaction} timeout={timeout}")
|
||||
|
||||
Reference in New Issue
Block a user