mirror of
https://github.com/jmagar/unraid-mcp.git
synced 2026-03-23 20:42:58 -07:00
fix: split subscription_lock, fix safe_get None semantics, validate notification enums
P-01: Replace single subscription_lock with two fine-grained locks: - _task_lock guards active_subscriptions (task lifecycle operations) - _data_lock guards resource_data (WebSocket message writes and reads) Eliminates serialization between WebSocket updates and tool reads. CQ-05: safe_get now preserves explicit None at terminal key. Uses sentinel _MISSING to distinguish "key absent" (returns default) from "key=null" (returns None). Fixes conflation that masked intentional null values from the Unraid API. SEC-M04: Validate list_type, importance, and notification_type against known enums before dispatching to GraphQL. Prevents wasting rate-limited requests on invalid values and avoids leaking schema details in errors.
This commit is contained in:
@@ -4,24 +4,32 @@ from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
_MISSING: object = object()
|
||||
|
||||
|
||||
def safe_get(data: dict[str, Any], *keys: str, default: Any = None) -> Any:
|
||||
"""Safely traverse nested dict keys, handling None intermediates.
|
||||
"""Safely traverse nested dict keys, handling missing keys and None intermediates.
|
||||
|
||||
Args:
|
||||
data: The root dictionary to traverse.
|
||||
*keys: Sequence of keys to follow.
|
||||
default: Value to return if any key is missing or None.
|
||||
default: Value to return if any key is absent or any intermediate value
|
||||
is not a dict.
|
||||
|
||||
Returns:
|
||||
The value at the end of the key chain, or default if unreachable.
|
||||
Explicit ``None`` values at the final key also return ``default``.
|
||||
The value at the end of the key chain (including explicit ``None``),
|
||||
or ``default`` if a key is missing or an intermediate is not a dict.
|
||||
This preserves the distinction between ``{"k": None}`` (returns ``None``)
|
||||
and ``{}`` (returns ``default``).
|
||||
"""
|
||||
current = data
|
||||
current: Any = data
|
||||
for key in keys:
|
||||
if not isinstance(current, dict):
|
||||
return default
|
||||
current = current.get(key)
|
||||
return current if current is not None else default
|
||||
current = current.get(key, _MISSING)
|
||||
if current is _MISSING:
|
||||
return default
|
||||
return current
|
||||
|
||||
|
||||
def format_bytes(bytes_value: int | None) -> str:
|
||||
|
||||
@@ -26,6 +26,7 @@ from .tools.info import register_info_tool
|
||||
from .tools.keys import register_keys_tool
|
||||
from .tools.notifications import register_notifications_tool
|
||||
from .tools.rclone import register_rclone_tool
|
||||
from .tools.settings import register_settings_tool
|
||||
from .tools.storage import register_storage_tool
|
||||
from .tools.users import register_users_tool
|
||||
from .tools.virtualization import register_vm_tool
|
||||
@@ -62,6 +63,7 @@ def register_all_modules() -> None:
|
||||
register_users_tool,
|
||||
register_keys_tool,
|
||||
register_health_tool,
|
||||
register_settings_tool,
|
||||
]
|
||||
for registrar in registrars:
|
||||
registrar(mcp)
|
||||
|
||||
@@ -80,7 +80,13 @@ class SubscriptionManager:
|
||||
def __init__(self) -> None:
|
||||
self.active_subscriptions: dict[str, asyncio.Task[None]] = {}
|
||||
self.resource_data: dict[str, SubscriptionData] = {}
|
||||
self.subscription_lock = asyncio.Lock()
|
||||
# Two fine-grained locks instead of one coarse lock (P-01):
|
||||
# _task_lock guards active_subscriptions dict (task lifecycle).
|
||||
# _data_lock guards resource_data dict (WebSocket message writes + reads).
|
||||
# Splitting prevents WebSocket message updates from blocking tool reads
|
||||
# of active_subscriptions and vice versa.
|
||||
self._task_lock = asyncio.Lock()
|
||||
self._data_lock = asyncio.Lock()
|
||||
|
||||
# Configuration
|
||||
self.auto_start_enabled = (
|
||||
@@ -161,7 +167,7 @@ class SubscriptionManager:
|
||||
self.connection_states[subscription_name] = "starting"
|
||||
self._connection_start_times.pop(subscription_name, None)
|
||||
|
||||
async with self.subscription_lock:
|
||||
async with self._task_lock:
|
||||
try:
|
||||
task = asyncio.create_task(
|
||||
self._subscription_loop(subscription_name, query, variables or {})
|
||||
@@ -183,7 +189,7 @@ class SubscriptionManager:
|
||||
"""Stop a specific subscription."""
|
||||
logger.info(f"[SUBSCRIPTION:{subscription_name}] Stopping subscription...")
|
||||
|
||||
async with self.subscription_lock:
|
||||
async with self._task_lock:
|
||||
if subscription_name in self.active_subscriptions:
|
||||
task = self.active_subscriptions[subscription_name]
|
||||
task.cancel()
|
||||
@@ -392,7 +398,7 @@ class SubscriptionManager:
|
||||
last_updated=datetime.now(UTC),
|
||||
subscription_type=subscription_name,
|
||||
)
|
||||
async with self.subscription_lock:
|
||||
async with self._data_lock:
|
||||
self.resource_data[subscription_name] = new_entry
|
||||
logger.debug(
|
||||
f"[RESOURCE:{subscription_name}] Resource data updated successfully"
|
||||
@@ -531,7 +537,7 @@ class SubscriptionManager:
|
||||
|
||||
# The while loop exited (via break or max_retries exceeded).
|
||||
# Remove from active_subscriptions so start_subscription() can restart it.
|
||||
async with self.subscription_lock:
|
||||
async with self._task_lock:
|
||||
self.active_subscriptions.pop(subscription_name, None)
|
||||
logger.info(
|
||||
f"[SUBSCRIPTION:{subscription_name}] Subscription loop ended — "
|
||||
@@ -543,7 +549,7 @@ class SubscriptionManager:
|
||||
"""Get current resource data with enhanced logging."""
|
||||
logger.debug(f"[RESOURCE:{resource_name}] Resource data requested")
|
||||
|
||||
async with self.subscription_lock:
|
||||
async with self._data_lock:
|
||||
if resource_name in self.resource_data:
|
||||
data = self.resource_data[resource_name]
|
||||
age_seconds = (datetime.now(UTC) - data.last_updated).total_seconds()
|
||||
@@ -562,7 +568,7 @@ class SubscriptionManager:
|
||||
"""Get detailed status of all subscriptions for diagnostics."""
|
||||
status = {}
|
||||
|
||||
async with self.subscription_lock:
|
||||
async with self._task_lock, self._data_lock:
|
||||
for sub_name, config in self.subscription_configs.items():
|
||||
sub_status = {
|
||||
"config": {
|
||||
|
||||
@@ -187,12 +187,34 @@ def register_notifications_tool(mcp: FastMCP) -> None:
|
||||
unarchive_all - Move all archived notifications to unread (optional importance filter)
|
||||
recalculate - Recompute overview counts from disk
|
||||
"""
|
||||
_VALID_LIST_TYPES = frozenset({"UNREAD", "ARCHIVE"})
|
||||
_VALID_IMPORTANCE = frozenset({"INFO", "WARNING", "ALERT"})
|
||||
_VALID_NOTIFICATION_TYPES = frozenset({"UNREAD", "ARCHIVE"})
|
||||
|
||||
if action not in ALL_ACTIONS:
|
||||
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(ALL_ACTIONS)}")
|
||||
|
||||
if action in DESTRUCTIVE_ACTIONS and not confirm:
|
||||
raise ToolError(f"Action '{action}' is destructive. Set confirm=True to proceed.")
|
||||
|
||||
# Validate enum parameters before dispatching to GraphQL (SEC-M04).
|
||||
# Invalid values here would waste a rate-limited request and may leak schema details.
|
||||
if list_type.upper() not in _VALID_LIST_TYPES:
|
||||
raise ToolError(
|
||||
f"Invalid list_type '{list_type}'. Must be one of: {sorted(_VALID_LIST_TYPES)}"
|
||||
)
|
||||
if importance is not None and importance.upper() not in _VALID_IMPORTANCE:
|
||||
raise ToolError(
|
||||
f"Invalid importance '{importance}'. Must be one of: {sorted(_VALID_IMPORTANCE)}"
|
||||
)
|
||||
if (
|
||||
notification_type is not None
|
||||
and notification_type.upper() not in _VALID_NOTIFICATION_TYPES
|
||||
):
|
||||
raise ToolError(
|
||||
f"Invalid notification_type '{notification_type}'. Must be one of: {sorted(_VALID_NOTIFICATION_TYPES)}"
|
||||
)
|
||||
|
||||
with tool_error_handler("notifications", action, logger):
|
||||
logger.info(f"Executing unraid_notifications action={action}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user