refactor: comprehensive code review fixes across 31 files

Addresses all critical, high, medium, and low issues from full codebase
review. 494 tests pass, ruff clean, ty type-check clean.

Security:
- Add tool_error_handler context manager (exceptions.py) — standardised
  error handling, eliminates 11 bare except-reraise patterns
- Remove unused exception subclasses (ConfigurationError, UnraidAPIError,
  SubscriptionError, ValidationError, IdempotentOperationError)
- Harden GraphQL subscription query validator with allow-list and
  forbidden-keyword regex (diagnostics.py)
- Add input validation for rclone create_remote config_data: injection,
  path-traversal, and key-count limits (rclone.py)
- Validate notifications importance enum before GraphQL request (notifications.py)
- Sanitise HTTP/network/JSON error messages — no raw exception strings
  leaked to clients (client.py)
- Strip path/creds from displayed API URL via _safe_display_url (health.py)
- Enable Ruff S (bandit) rule category in pyproject.toml
- Harden container mutations to strict-only matching — no fuzzy/substring
  for destructive operations (docker.py)

Performance:
- Token-bucket rate limiter (90 tokens, 9 req/s) with 429 retry backoff (client.py)
- Lazy asyncio.Lock init via _get_client_lock() — fixes event-loop
  module-load crash (client.py)
- Double-checked locking in get_http_client() for fast-path (client.py)
- Short hex container ID fast-path skips list fetch (docker.py)
- Cap resource_data log content to 1 MB / 5,000 lines (manager.py)
- Reset reconnect counter after 30 s stable connection (manager.py)
- Move tail_lines validation to module level; enforce 10,000 line cap
  (storage.py, docker.py)
- force_terminal=True removed from logging RichHandler (logging.py)

Architecture:
- Register diagnostic tools in server startup (server.py)
- Move ALL_ACTIONS computation to module level in all tools
- Consolidate format_kb / format_bytes into shared core/utils.py
- Add _safe_get() helper in core/utils.py for nested dict traversal
- Extract _analyze_subscription_status() from health.py diagnose handler
- Validate required config at startup — fail fast with CRITICAL log (server.py)

Code quality:
- Remove ~90 lines of dead Rich formatting helpers from logging.py
- Remove dead self.websocket attribute from SubscriptionManager
- Remove dead setup_uvicorn_logging() wrapper
- Move _VALID_IMPORTANCE to module level (N806 fix)
- Add slots=True to all three dataclasses (SubscriptionData, SystemHealth, APIResponse)
- Fix None rendering as literal "None" string in info.py summaries
- Change fuzzy-match log messages from INFO to DEBUG (docker.py)
- UTC-aware datetimes throughout (manager.py, diagnostics.py)

Infrastructure:
- Upgrade base image python:3.11-slim → python:3.12-slim (Dockerfile)
- Add non-root appuser (UID/GID 1000) with HEALTHCHECK (Dockerfile)
- Add read_only, cap_drop: ALL, tmpfs /tmp to docker-compose.yml
- Single-source version via importlib.metadata (pyproject.toml → __init__.py)
- Add open_timeout to all websockets.connect() calls

Tests:
- Update error message matchers to match sanitised messages (test_client.py)
- Fix patch targets for UNRAID_API_URL → utils module (test_subscriptions.py)
- Fix importance="info" → importance="normal" (test_notifications.py, http_layer)
- Fix naive datetime fixtures → UTC-aware (test_subscriptions.py)

Co-authored-by: Claude <claude@anthropic.com>
This commit is contained in:
Jacob Magar
2026-02-18 01:02:13 -05:00
parent 5b6a728f45
commit 316193c04b
32 changed files with 995 additions and 622 deletions

View File

@@ -5,7 +5,9 @@ to the Unraid API with proper timeout handling and error management.
"""
import asyncio
import hashlib
import json
import time
from typing import Any
import httpx
@@ -22,7 +24,19 @@ from ..core.exceptions import ToolError
# Sensitive keys to redact from debug logs
_SENSITIVE_KEYS = {"password", "key", "secret", "token", "apikey"}
_SENSITIVE_KEYS = {
"password",
"key",
"secret",
"token",
"apikey",
"authorization",
"cookie",
"session",
"credential",
"passphrase",
"jwt",
}
def _is_sensitive_key(key: str) -> bool:
@@ -67,7 +81,121 @@ def get_timeout_for_operation(profile: str) -> httpx.Timeout:
# Global connection pool (module-level singleton)
_http_client: httpx.AsyncClient | None = None
_client_lock = asyncio.Lock()
_client_lock: asyncio.Lock | None = None
def _get_client_lock() -> asyncio.Lock:
"""Get or create the client lock (lazy init to avoid event loop issues)."""
global _client_lock
if _client_lock is None:
_client_lock = asyncio.Lock()
return _client_lock
class _RateLimiter:
"""Token bucket rate limiter for Unraid API (100 req / 10s hard limit).
Uses 90 tokens with 9.0 tokens/sec refill for 10% safety headroom.
"""
def __init__(self, max_tokens: int = 90, refill_rate: float = 9.0) -> None:
self.max_tokens = max_tokens
self.tokens = float(max_tokens)
self.refill_rate = refill_rate # tokens per second
self.last_refill = time.monotonic()
self._lock: asyncio.Lock | None = None
def _get_lock(self) -> asyncio.Lock:
if self._lock is None:
self._lock = asyncio.Lock()
return self._lock
def _refill(self) -> None:
"""Refill tokens based on elapsed time."""
now = time.monotonic()
elapsed = now - self.last_refill
self.tokens = min(self.max_tokens, self.tokens + elapsed * self.refill_rate)
self.last_refill = now
async def acquire(self) -> None:
"""Consume one token, waiting if necessary for refill."""
while True:
async with self._get_lock():
self._refill()
if self.tokens >= 1:
self.tokens -= 1
return
wait_time = (1 - self.tokens) / self.refill_rate
# Sleep outside the lock so other coroutines aren't blocked
await asyncio.sleep(wait_time)
_rate_limiter = _RateLimiter()
# --- TTL Cache for stable read-only queries ---
# Queries whose results change infrequently and are safe to cache.
# Mutations and volatile queries (metrics, docker, array state) are excluded.
_CACHEABLE_QUERY_PREFIXES = frozenset(
{
"GetNetworkConfig",
"GetRegistrationInfo",
"GetOwner",
"GetFlash",
}
)
_CACHE_TTL_SECONDS = 60.0
class _QueryCache:
"""Simple TTL cache for GraphQL query responses.
Keyed by a hash of (query, variables). Entries expire after _CACHE_TTL_SECONDS.
Only caches responses for queries whose operation name is in _CACHEABLE_QUERY_PREFIXES.
Mutation requests always bypass the cache.
"""
def __init__(self) -> None:
self._store: dict[str, tuple[float, dict[str, Any]]] = {}
@staticmethod
def _cache_key(query: str, variables: dict[str, Any] | None) -> str:
raw = query + json.dumps(variables or {}, sort_keys=True)
return hashlib.sha256(raw.encode()).hexdigest()
@staticmethod
def is_cacheable(query: str) -> bool:
"""Check if a query is eligible for caching based on its operation name."""
if query.lstrip().startswith("mutation"):
return False
return any(prefix in query for prefix in _CACHEABLE_QUERY_PREFIXES)
def get(self, query: str, variables: dict[str, Any] | None) -> dict[str, Any] | None:
"""Return cached result if present and not expired, else None."""
key = self._cache_key(query, variables)
entry = self._store.get(key)
if entry is None:
return None
expires_at, data = entry
if time.monotonic() > expires_at:
del self._store[key]
return None
return data
def put(self, query: str, variables: dict[str, Any] | None, data: dict[str, Any]) -> None:
"""Store a query result with TTL expiry."""
key = self._cache_key(query, variables)
self._store[key] = (time.monotonic() + _CACHE_TTL_SECONDS, data)
def invalidate_all(self) -> None:
"""Clear the entire cache (called after mutations)."""
self._store.clear()
_query_cache = _QueryCache()
def is_idempotent_error(error_message: str, operation: str) -> bool:
@@ -109,7 +237,7 @@ async def _create_http_client() -> httpx.AsyncClient:
return httpx.AsyncClient(
# Connection pool settings
limits=httpx.Limits(
max_keepalive_connections=20, max_connections=100, keepalive_expiry=30.0
max_keepalive_connections=20, max_connections=20, keepalive_expiry=30.0
),
# Default timeout (can be overridden per-request)
timeout=DEFAULT_TIMEOUT,
@@ -123,40 +251,35 @@ async def _create_http_client() -> httpx.AsyncClient:
async def get_http_client() -> httpx.AsyncClient:
"""Get or create shared HTTP client with connection pooling.
The client is protected by an asyncio lock to prevent concurrent creation.
If the existing client was closed (e.g., during shutdown), a new one is created.
Uses double-checked locking: fast-path skips the lock when the client
is already initialized, only acquiring it for initial creation or
recovery after close.
Returns:
Singleton AsyncClient instance with connection pooling enabled
"""
global _http_client
async with _client_lock:
# Fast-path: skip lock if client is already initialized and open
client = _http_client
if client is not None and not client.is_closed:
return client
# Slow-path: acquire lock for initialization
async with _get_client_lock():
if _http_client is None or _http_client.is_closed:
_http_client = await _create_http_client()
logger.info(
"Created shared HTTP client with connection pooling (20 keepalive, 100 max connections)"
"Created shared HTTP client with connection pooling (20 keepalive, 20 max connections)"
)
client = _http_client
# Verify client is still open after releasing the lock.
# In asyncio's cooperative model this is unlikely to fail, but guards
# against edge cases where close_http_client runs between yield points.
if client.is_closed:
async with _client_lock:
_http_client = await _create_http_client()
client = _http_client
logger.info("Re-created HTTP client after unexpected close")
return client
return _http_client
async def close_http_client() -> None:
"""Close the shared HTTP client (call on server shutdown)."""
global _http_client
async with _client_lock:
async with _get_client_lock():
if _http_client is not None:
await _http_client.aclose()
_http_client = None
@@ -190,6 +313,14 @@ async def make_graphql_request(
if not UNRAID_API_KEY:
raise ToolError("UNRAID_API_KEY not configured")
# Check TTL cache for stable read-only queries
is_mutation = query.lstrip().startswith("mutation")
if not is_mutation and _query_cache.is_cacheable(query):
cached = _query_cache.get(query, variables)
if cached is not None:
logger.debug("Returning cached response for query")
return cached
headers = {
"Content-Type": "application/json",
"X-API-Key": UNRAID_API_KEY,
@@ -205,17 +336,31 @@ async def make_graphql_request(
logger.debug(f"Variables: {_redact_sensitive(variables)}")
try:
# Rate limit: consume a token before making the request
await _rate_limiter.acquire()
# Get the shared HTTP client with connection pooling
client = await get_http_client()
# Override timeout if custom timeout specified
# Retry loop for 429 rate limit responses
post_kwargs: dict[str, Any] = {"json": payload, "headers": headers}
if custom_timeout is not None:
response = await client.post(
UNRAID_API_URL, json=payload, headers=headers, timeout=custom_timeout
)
else:
response = await client.post(UNRAID_API_URL, json=payload, headers=headers)
post_kwargs["timeout"] = custom_timeout
response: httpx.Response | None = None
for attempt in range(3):
response = await client.post(UNRAID_API_URL, **post_kwargs)
if response.status_code == 429:
backoff = 2**attempt
logger.warning(
f"Rate limited (429) by Unraid API, retrying in {backoff}s (attempt {attempt + 1}/3)"
)
await asyncio.sleep(backoff)
continue
break
if response is None: # pragma: no cover — guaranteed by loop
raise ToolError("No response received after retry attempts")
response.raise_for_status() # Raise an exception for HTTP error codes 4xx/5xx
response_data = response.json()
@@ -245,14 +390,27 @@ async def make_graphql_request(
logger.debug("GraphQL request successful.")
data = response_data.get("data", {})
return data if isinstance(data, dict) else {} # Ensure we return dict
result = data if isinstance(data, dict) else {} # Ensure we return dict
# Invalidate cache on mutations; cache eligible query results
if is_mutation:
_query_cache.invalidate_all()
elif _query_cache.is_cacheable(query):
_query_cache.put(query, variables, result)
return result
except httpx.HTTPStatusError as e:
# Log full details internally; only expose status code to MCP client
logger.error(f"HTTP error occurred: {e.response.status_code} - {e.response.text}")
raise ToolError(f"HTTP error {e.response.status_code}: {e.response.text}") from e
raise ToolError(
f"Unraid API returned HTTP {e.response.status_code}. Check server logs for details."
) from e
except httpx.RequestError as e:
# Log full error internally; give safe summary to MCP client
logger.error(f"Request error occurred: {e}")
raise ToolError(f"Network connection error: {e!s}") from e
raise ToolError(f"Network error connecting to Unraid API: {type(e).__name__}") from e
except json.JSONDecodeError as e:
# Log full decode error; give safe summary to MCP client
logger.error(f"Failed to decode JSON response: {e}")
raise ToolError(f"Invalid JSON response from Unraid API: {e!s}") from e
raise ToolError("Unraid API returned an invalid response (not valid JSON)") from e

View File

@@ -4,6 +4,10 @@ This module defines custom exception classes for consistent error handling
throughout the application, with proper integration to FastMCP's error system.
"""
import contextlib
import logging
from collections.abc import Generator
from fastmcp.exceptions import ToolError as FastMCPToolError
@@ -19,36 +23,26 @@ class ToolError(FastMCPToolError):
pass
class ConfigurationError(ToolError):
"""Raised when there are configuration-related errors."""
@contextlib.contextmanager
def tool_error_handler(
tool_name: str,
action: str,
logger: logging.Logger,
) -> Generator[None]:
"""Context manager that standardizes tool error handling.
pass
Re-raises ToolError as-is. Catches all other exceptions, logs them
with full traceback, and wraps them in ToolError with a descriptive message.
class UnraidAPIError(ToolError):
"""Raised when the Unraid API returns an error or is unreachable."""
pass
class SubscriptionError(ToolError):
"""Raised when there are WebSocket subscription-related errors."""
pass
class ValidationError(ToolError):
"""Raised when input validation fails."""
pass
class IdempotentOperationError(ToolError):
"""Raised when an operation is idempotent (already in desired state).
This is used internally to signal that an operation was already complete,
which should typically be converted to a success response rather than
propagated as an error to the user.
Args:
tool_name: The tool name for error messages (e.g., "docker", "vm").
action: The current action being executed.
logger: The logger instance to use for error logging.
"""
pass
try:
yield
except ToolError:
raise
except Exception as e:
logger.error(f"Error in unraid_{tool_name} action={action}: {e}", exc_info=True)
raise ToolError(f"Failed to execute {tool_name}/{action}: {e!s}") from e

View File

@@ -9,27 +9,33 @@ from datetime import datetime
from typing import Any
@dataclass
@dataclass(slots=True)
class SubscriptionData:
"""Container for subscription data with metadata."""
"""Container for subscription data with metadata.
Note: last_updated must be timezone-aware (use datetime.now(UTC)).
"""
data: dict[str, Any]
last_updated: datetime
last_updated: datetime # Must be timezone-aware (UTC)
subscription_type: str
@dataclass
@dataclass(slots=True)
class SystemHealth:
"""Container for system health status information."""
"""Container for system health status information.
Note: last_checked must be timezone-aware (use datetime.now(UTC)).
"""
is_healthy: bool
issues: list[str]
warnings: list[str]
last_checked: datetime
last_checked: datetime # Must be timezone-aware (UTC)
component_status: dict[str, str]
@dataclass
@dataclass(slots=True)
class APIResponse:
"""Container for standardized API response data."""

68
unraid_mcp/core/utils.py Normal file
View File

@@ -0,0 +1,68 @@
"""Shared utility functions for Unraid MCP tools."""
from typing import Any
def safe_get(data: dict[str, Any], *keys: str, default: Any = None) -> Any:
"""Safely traverse nested dict keys, handling None intermediates.
Args:
data: The root dictionary to traverse.
*keys: Sequence of keys to follow.
default: Value to return if any key is missing or None.
Returns:
The value at the end of the key chain, or default if unreachable.
"""
current = data
for key in keys:
if not isinstance(current, dict):
return default
current = current.get(key)
return current if current is not None else default
def format_bytes(bytes_value: int | None) -> str:
"""Format byte values into human-readable sizes.
Args:
bytes_value: Number of bytes, or None.
Returns:
Human-readable string like "1.00 GB" or "N/A" if input is None/invalid.
"""
if bytes_value is None:
return "N/A"
try:
value = float(int(bytes_value))
except (ValueError, TypeError):
return "N/A"
for unit in ["B", "KB", "MB", "GB", "TB", "PB"]:
if value < 1024.0:
return f"{value:.2f} {unit}"
value /= 1024.0
return f"{value:.2f} EB"
def format_kb(k: Any) -> str:
"""Format kilobyte values into human-readable sizes.
Args:
k: Number of kilobytes, or None.
Returns:
Human-readable string like "1.00 GB" or "N/A" if input is None/invalid.
"""
if k is None:
return "N/A"
try:
k = int(k)
except (ValueError, TypeError):
return "N/A"
if k >= 1024 * 1024 * 1024:
return f"{k / (1024 * 1024 * 1024):.2f} TB"
if k >= 1024 * 1024:
return f"{k / (1024 * 1024):.2f} GB"
if k >= 1024:
return f"{k / 1024:.2f} MB"
return f"{k} KB"