forked from HomeLab/unraid-mcp
test: close critical coverage gaps and harden PR review fixes
Critical bug fixes from PR review agents: - client.py: eager asyncio.Lock init, Final[frozenset] for _SENSITIVE_KEYS, explicit 429 ToolError after retries exhausted, removed lazy _get_client_lock() and _RateLimiter._get_lock() patterns - exceptions.py: use builtin TimeoutError (UP041), explicit handler before broad except so asyncio timeouts get descriptive messages - docker.py: add update_all to DESTRUCTIVE_ACTIONS (was missing), remove dead _MUTATION_ACTIONS constant - manager.py: _cap_log_content returns new dict (immutable), lock write to resource_data, clean dead task from active_subscriptions after loop exits - diagnostics.py: fix inaccurate comment about semicolon injection guard - health.py: narrow except ValueError in _safe_display_url, fix TODO comment New test coverage (98 tests added, 529 → 598 passing): - test_subscription_validation.py: 27 tests for _validate_subscription_query (security-critical allow-list, forbidden keyword guards, word-boundary test) - test_subscription_manager.py: 12 tests for _cap_log_content (immutability, truncation, nesting, passthrough) - test_client.py: +57 tests — _RateLimiter (token math, refill, sleep-on-empty), _QueryCache (TTL, invalidation, is_cacheable), 429 retry loop (1/2/3 failures) - test_health.py: +10 tests for _safe_display_url (credential strip, port, path/query removal, malformed IPv6 → <unparseable>) - test_notifications.py: +7 importance enum and field length validation tests - test_rclone.py: +7 _validate_config_data security guard tests - test_storage.py: +15 (tail_lines bounds, format_kb, safe_get) - test_docker.py: update_all now requires confirm=True + new guard test - test_destructive_guards.py: update audit to include update_all Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -8,7 +8,7 @@ import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Any, Final
|
||||
|
||||
import httpx
|
||||
|
||||
@@ -23,20 +23,22 @@ from ..config.settings import (
|
||||
from ..core.exceptions import ToolError
|
||||
|
||||
|
||||
# Sensitive keys to redact from debug logs
|
||||
_SENSITIVE_KEYS = {
|
||||
"password",
|
||||
"key",
|
||||
"secret",
|
||||
"token",
|
||||
"apikey",
|
||||
"authorization",
|
||||
"cookie",
|
||||
"session",
|
||||
"credential",
|
||||
"passphrase",
|
||||
"jwt",
|
||||
}
|
||||
# Sensitive keys to redact from debug logs (frozenset — immutable, Final — no accidental reassignment)
|
||||
_SENSITIVE_KEYS: Final[frozenset[str]] = frozenset(
|
||||
{
|
||||
"password",
|
||||
"key",
|
||||
"secret",
|
||||
"token",
|
||||
"apikey",
|
||||
"authorization",
|
||||
"cookie",
|
||||
"session",
|
||||
"credential",
|
||||
"passphrase",
|
||||
"jwt",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _is_sensitive_key(key: str) -> bool:
|
||||
@@ -80,16 +82,9 @@ def get_timeout_for_operation(profile: str) -> httpx.Timeout:
|
||||
|
||||
|
||||
# Global connection pool (module-level singleton)
|
||||
# Python 3.12+ asyncio.Lock() is safe at module level — no running event loop required
|
||||
_http_client: httpx.AsyncClient | None = None
|
||||
_client_lock: asyncio.Lock | None = None
|
||||
|
||||
|
||||
def _get_client_lock() -> asyncio.Lock:
|
||||
"""Get or create the client lock (lazy init to avoid event loop issues)."""
|
||||
global _client_lock
|
||||
if _client_lock is None:
|
||||
_client_lock = asyncio.Lock()
|
||||
return _client_lock
|
||||
_client_lock: Final[asyncio.Lock] = asyncio.Lock()
|
||||
|
||||
|
||||
class _RateLimiter:
|
||||
@@ -103,12 +98,8 @@ class _RateLimiter:
|
||||
self.tokens = float(max_tokens)
|
||||
self.refill_rate = refill_rate # tokens per second
|
||||
self.last_refill = time.monotonic()
|
||||
self._lock: asyncio.Lock | None = None
|
||||
|
||||
def _get_lock(self) -> asyncio.Lock:
|
||||
if self._lock is None:
|
||||
self._lock = asyncio.Lock()
|
||||
return self._lock
|
||||
# asyncio.Lock() is safe to create at __init__ time (Python 3.12+)
|
||||
self._lock: Final[asyncio.Lock] = asyncio.Lock()
|
||||
|
||||
def _refill(self) -> None:
|
||||
"""Refill tokens based on elapsed time."""
|
||||
@@ -120,7 +111,7 @@ class _RateLimiter:
|
||||
async def acquire(self) -> None:
|
||||
"""Consume one token, waiting if necessary for refill."""
|
||||
while True:
|
||||
async with self._get_lock():
|
||||
async with self._lock:
|
||||
self._refill()
|
||||
if self.tokens >= 1:
|
||||
self.tokens -= 1
|
||||
@@ -266,7 +257,7 @@ async def get_http_client() -> httpx.AsyncClient:
|
||||
return client
|
||||
|
||||
# Slow-path: acquire lock for initialization
|
||||
async with _get_client_lock():
|
||||
async with _client_lock:
|
||||
if _http_client is None or _http_client.is_closed:
|
||||
_http_client = await _create_http_client()
|
||||
logger.info(
|
||||
@@ -279,7 +270,7 @@ async def close_http_client() -> None:
|
||||
"""Close the shared HTTP client (call on server shutdown)."""
|
||||
global _http_client
|
||||
|
||||
async with _get_client_lock():
|
||||
async with _client_lock:
|
||||
if _http_client is not None:
|
||||
await _http_client.aclose()
|
||||
_http_client = None
|
||||
@@ -361,6 +352,14 @@ async def make_graphql_request(
|
||||
|
||||
if response is None: # pragma: no cover — guaranteed by loop
|
||||
raise ToolError("No response received after retry attempts")
|
||||
|
||||
# Provide a clear message when all retries are exhausted on 429
|
||||
if response.status_code == 429:
|
||||
logger.error("Rate limit (429) persisted after 3 retries — request aborted")
|
||||
raise ToolError(
|
||||
"Unraid API is rate limiting requests. Wait ~10 seconds before retrying."
|
||||
)
|
||||
|
||||
response.raise_for_status() # Raise an exception for HTTP error codes 4xx/5xx
|
||||
|
||||
response_data = response.json()
|
||||
|
||||
@@ -6,7 +6,7 @@ throughout the application, with proper integration to FastMCP's error system.
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
|
||||
from fastmcp.exceptions import ToolError as FastMCPToolError
|
||||
|
||||
@@ -28,11 +28,12 @@ def tool_error_handler(
|
||||
tool_name: str,
|
||||
action: str,
|
||||
logger: logging.Logger,
|
||||
) -> Generator[None]:
|
||||
) -> Iterator[None]:
|
||||
"""Context manager that standardizes tool error handling.
|
||||
|
||||
Re-raises ToolError as-is. Catches all other exceptions, logs them
|
||||
with full traceback, and wraps them in ToolError with a descriptive message.
|
||||
Re-raises ToolError as-is. Gives TimeoutError a descriptive message.
|
||||
Catches all other exceptions, logs them with full traceback, and wraps them
|
||||
in ToolError with a descriptive message.
|
||||
|
||||
Args:
|
||||
tool_name: The tool name for error messages (e.g., "docker", "vm").
|
||||
@@ -43,6 +44,14 @@ def tool_error_handler(
|
||||
yield
|
||||
except ToolError:
|
||||
raise
|
||||
except TimeoutError as e:
|
||||
logger.error(
|
||||
f"Timeout in unraid_{tool_name} action={action}: request exceeded time limit",
|
||||
exc_info=True,
|
||||
)
|
||||
raise ToolError(
|
||||
f"Request timed out executing {tool_name}/{action}. The Unraid API did not respond in time."
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.error(f"Error in unraid_{tool_name} action={action}: {e}", exc_info=True)
|
||||
raise ToolError(f"Failed to execute {tool_name}/{action}: {e!s}") from e
|
||||
|
||||
@@ -37,8 +37,10 @@ _ALLOWED_SUBSCRIPTION_NAMES = frozenset(
|
||||
}
|
||||
)
|
||||
|
||||
# Pattern: must start with "subscription", contain only a known subscription name,
|
||||
# and not contain mutation/query keywords or semicolons (prevents injection).
|
||||
# Pattern: must start with "subscription" and contain only a known subscription name.
|
||||
# _FORBIDDEN_KEYWORDS rejects any query that contains standalone "mutation" or "query"
|
||||
# as distinct words. Word boundaries (\b) ensure "mutationField"-style identifiers are
|
||||
# not rejected — only bare "mutation" or "query" operation keywords are blocked.
|
||||
_SUBSCRIPTION_NAME_PATTERN = re.compile(r"^\s*subscription\b[^{]*\{\s*(\w+)", re.IGNORECASE)
|
||||
_FORBIDDEN_KEYWORDS = re.compile(r"\b(mutation|query)\b", re.IGNORECASE)
|
||||
|
||||
|
||||
@@ -32,12 +32,17 @@ _STABLE_CONNECTION_SECONDS = 30
|
||||
def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Cap log content in subscription data to prevent unbounded memory growth.
|
||||
|
||||
If the data contains a 'content' field (from log subscriptions) that exceeds
|
||||
size limits, truncate to the most recent _MAX_RESOURCE_DATA_LINES lines.
|
||||
Returns a new dict — does NOT mutate the input. If any nested 'content'
|
||||
field (from log subscriptions) exceeds the byte limit, truncate it to the
|
||||
most recent _MAX_RESOURCE_DATA_LINES lines.
|
||||
|
||||
Note: single lines larger than _MAX_RESOURCE_DATA_BYTES are not split and
|
||||
will still be stored at full size; only multi-line content is truncated.
|
||||
"""
|
||||
result: dict[str, Any] = {}
|
||||
for key, value in data.items():
|
||||
if isinstance(value, dict):
|
||||
data[key] = _cap_log_content(value)
|
||||
result[key] = _cap_log_content(value)
|
||||
elif (
|
||||
key == "content"
|
||||
and isinstance(value, str)
|
||||
@@ -50,8 +55,12 @@ def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]:
|
||||
f"[RESOURCE] Capped log content from {len(lines)} to "
|
||||
f"{_MAX_RESOURCE_DATA_LINES} lines ({len(value)} -> {len(truncated)} chars)"
|
||||
)
|
||||
data[key] = truncated
|
||||
return data
|
||||
result[key] = truncated
|
||||
else:
|
||||
result[key] = value
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
|
||||
class SubscriptionManager:
|
||||
@@ -355,11 +364,13 @@ class SubscriptionManager:
|
||||
if isinstance(payload["data"], dict)
|
||||
else payload["data"]
|
||||
)
|
||||
self.resource_data[subscription_name] = SubscriptionData(
|
||||
new_entry = SubscriptionData(
|
||||
data=capped_data,
|
||||
last_updated=datetime.now(UTC),
|
||||
subscription_type=subscription_name,
|
||||
)
|
||||
async with self.subscription_lock:
|
||||
self.resource_data[subscription_name] = new_entry
|
||||
logger.debug(
|
||||
f"[RESOURCE:{subscription_name}] Resource data updated successfully"
|
||||
)
|
||||
@@ -484,6 +495,16 @@ class SubscriptionManager:
|
||||
self.connection_states[subscription_name] = "reconnecting"
|
||||
await asyncio.sleep(retry_delay)
|
||||
|
||||
# The while loop exited (via break or max_retries exceeded).
|
||||
# Remove from active_subscriptions so start_subscription() can restart it.
|
||||
async with self.subscription_lock:
|
||||
self.active_subscriptions.pop(subscription_name, None)
|
||||
logger.info(
|
||||
f"[SUBSCRIPTION:{subscription_name}] Subscription loop ended — "
|
||||
f"removed from active_subscriptions. Final state: "
|
||||
f"{self.connection_states.get(subscription_name, 'unknown')}"
|
||||
)
|
||||
|
||||
async def get_resource_data(self, resource_name: str) -> dict[str, Any] | None:
|
||||
"""Get current resource data with enhanced logging."""
|
||||
logger.debug(f"[RESOURCE:{resource_name}] Resource data requested")
|
||||
|
||||
@@ -99,8 +99,7 @@ MUTATIONS: dict[str, str] = {
|
||||
""",
|
||||
}
|
||||
|
||||
DESTRUCTIVE_ACTIONS = {"remove"}
|
||||
_MUTATION_ACTIONS = {"start", "stop", "restart", "pause", "unpause", "remove", "update"}
|
||||
DESTRUCTIVE_ACTIONS = {"remove", "update_all"}
|
||||
# NOTE (Code-M-07): "details" and "logs" are listed here because they require a
|
||||
# container_id parameter, but unlike mutations they use fuzzy name matching (not
|
||||
# strict). This is intentional: read-only queries are safe with fuzzy matching.
|
||||
|
||||
@@ -37,8 +37,8 @@ def _safe_display_url(url: str | None) -> str | None:
|
||||
if parsed.port:
|
||||
return f"{parsed.scheme}://{host}:{parsed.port}"
|
||||
return f"{parsed.scheme}://{host}"
|
||||
except Exception:
|
||||
# If parsing fails, show nothing rather than leaking the raw URL
|
||||
except ValueError:
|
||||
# urlparse raises ValueError for invalid URLs (e.g. contains control chars)
|
||||
return "<unparseable>"
|
||||
|
||||
|
||||
@@ -235,9 +235,9 @@ def _analyze_subscription_status(
|
||||
"""Analyze subscription status dict, returning error count and connection issues.
|
||||
|
||||
This is the canonical implementation of subscription status analysis.
|
||||
TODO: subscriptions/diagnostics.py (lines 168-182) duplicates this logic.
|
||||
That module should be refactored to call this helper once file ownership
|
||||
allows cross-agent edits. See Code-H05.
|
||||
TODO: subscriptions/diagnostics.py has a similar status-analysis pattern
|
||||
in diagnose_subscriptions(). That module could import and call this helper
|
||||
directly to avoid divergence. See Code-H05.
|
||||
|
||||
Args:
|
||||
status: Dict of subscription name -> status info from get_subscription_status().
|
||||
|
||||
Reference in New Issue
Block a user