feat: add 5 notification mutations + comprehensive refactors from PR review

New notification actions (archive_many, create_unique, unarchive_many,
unarchive_all, recalculate) bring unraid_notifications to 14 actions.

Also includes continuation of CodeRabbit/PR review fixes:
- Remove redundant try-except in virtualization.py (silent failure fix)
- Add QueryCache protocol with get/put/invalidate_all to core/client.py
- Refactor subscriptions (manager, diagnostics, resources, utils)
- Update config (logging, settings) for improved structure
- Expand test coverage: http_layer, safety guards, schema validation
- Minor cleanups: array, docker, health, keys tools

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Jacob Magar
2026-03-13 01:54:55 -04:00
parent 06f18f32fc
commit 60defc35ca
27 changed files with 2508 additions and 423 deletions

View File

@@ -53,21 +53,22 @@ class OverwriteFileHandler(logging.FileHandler):
):
try:
base_path = Path(self.baseFilename)
if base_path.exists():
file_size = base_path.stat().st_size
if file_size >= self.max_bytes:
# Close current stream
if self.stream:
self.stream.close()
# Remove the old file and start fresh
if base_path.exists():
base_path.unlink()
# Reopen with truncate mode
file_size = base_path.stat().st_size if base_path.exists() else 0
if file_size >= self.max_bytes:
old_stream = self.stream
self.stream = None
try:
old_stream.close()
base_path.unlink(missing_ok=True)
self.stream = self._open()
except OSError:
# Recovery: attempt to reopen even if unlink failed
try:
self.stream = self._open()
except OSError:
self.stream = old_stream # Last resort: restore original
# Log a marker that the file was reset
if self.stream is not None:
reset_record = logging.LogRecord(
name="UnraidMCPServer.Logging",
level=logging.INFO,
@@ -184,27 +185,8 @@ def configure_fastmcp_logger_with_rich() -> logging.Logger | None:
fastmcp_logger.setLevel(numeric_log_level)
# Also configure the root logger to catch any other logs
root_logger = logging.getLogger()
root_logger.handlers.clear()
root_logger.propagate = False
# Rich Console Handler for root logger
root_console_handler = RichHandler(
console=console,
show_time=True,
show_level=True,
show_path=False,
rich_tracebacks=True,
tracebacks_show_locals=False,
markup=True,
)
root_console_handler.setLevel(numeric_log_level)
root_logger.addHandler(root_console_handler)
# Reuse the shared file handler for root logger
root_logger.addHandler(_shared_file_handler)
root_logger.setLevel(numeric_log_level)
# Set root logger level to avoid suppressing library warnings entirely
logging.getLogger().setLevel(numeric_log_level)
return fastmcp_logger

View File

@@ -36,8 +36,27 @@ for dotenv_path in dotenv_paths:
UNRAID_API_URL = os.getenv("UNRAID_API_URL")
UNRAID_API_KEY = os.getenv("UNRAID_API_KEY")
# Server Configuration
UNRAID_MCP_PORT = int(os.getenv("UNRAID_MCP_PORT", "6970"))
def _parse_port(env_var: str, default: int) -> int:
"""Parse a port number from environment variable with validation."""
raw = os.getenv(env_var, str(default))
try:
port = int(raw)
except ValueError:
import sys
print(f"FATAL: {env_var}={raw!r} is not a valid integer port number", file=sys.stderr)
sys.exit(1)
if not (1 <= port <= 65535):
import sys
print(f"FATAL: {env_var}={port} outside valid port range 1-65535", file=sys.stderr)
sys.exit(1)
return port
UNRAID_MCP_PORT = _parse_port("UNRAID_MCP_PORT", 6970)
UNRAID_MCP_HOST = os.getenv("UNRAID_MCP_HOST", "0.0.0.0") # noqa: S104 — intentional for Docker
UNRAID_MCP_TRANSPORT = os.getenv("UNRAID_MCP_TRANSPORT", "streamable-http").lower()
@@ -58,7 +77,7 @@ IS_DOCKER = Path("/.dockerenv").exists()
LOGS_DIR = Path("/app/logs") if IS_DOCKER else PROJECT_ROOT / "logs"
LOG_FILE_PATH = LOGS_DIR / LOG_FILE_NAME
# Ensure logs directory exists; if creation fails, fall back to /tmp.
# Ensure logs directory exists; if creation fails, fall back to PROJECT_ROOT / ".cache" / "logs".
try:
LOGS_DIR.mkdir(parents=True, exist_ok=True)
except OSError:
@@ -97,9 +116,11 @@ def get_config_summary() -> dict[str, Any]:
"""
is_valid, missing = validate_required_config()
from ..core.utils import safe_display_url
return {
"api_url_configured": bool(UNRAID_API_URL),
"api_url_preview": UNRAID_API_URL[:20] + "..." if UNRAID_API_URL else None,
"api_url_preview": safe_display_url(UNRAID_API_URL) if UNRAID_API_URL else None,
"api_key_configured": bool(UNRAID_API_KEY),
"server_host": UNRAID_MCP_HOST,
"server_port": UNRAID_MCP_PORT,
@@ -110,5 +131,7 @@ def get_config_summary() -> dict[str, Any]:
"config_valid": is_valid,
"missing_config": missing if not is_valid else None,
}
# Re-export application version from a single source of truth.
VERSION = APP_VERSION

View File

@@ -51,9 +51,7 @@ def _is_sensitive_key(key: str) -> bool:
def redact_sensitive(obj: Any) -> Any:
"""Recursively redact sensitive values from nested dicts/lists."""
if isinstance(obj, dict):
return {
k: ("***" if _is_sensitive_key(k) else redact_sensitive(v)) for k, v in obj.items()
}
return {k: ("***" if _is_sensitive_key(k) else redact_sensitive(v)) for k, v in obj.items()}
if isinstance(obj, list):
return [redact_sensitive(item) for item in obj]
return obj
@@ -149,10 +147,16 @@ class _QueryCache:
Keyed by a hash of (query, variables). Entries expire after _CACHE_TTL_SECONDS.
Only caches responses for queries whose operation name is in _CACHEABLE_QUERY_PREFIXES.
Mutation requests always bypass the cache.
Thread-safe via asyncio.Lock. Bounded to _MAX_ENTRIES with FIFO eviction (oldest
expiry timestamp evicted first when the store is full).
"""
_MAX_ENTRIES: Final[int] = 256
def __init__(self) -> None:
self._store: dict[str, tuple[float, dict[str, Any]]] = {}
self._lock: Final[asyncio.Lock] = asyncio.Lock()
@staticmethod
def _cache_key(query: str, variables: dict[str, Any] | None) -> str:
@@ -170,26 +174,32 @@ class _QueryCache:
return False
return match.group(1) in _CACHEABLE_QUERY_PREFIXES
def get(self, query: str, variables: dict[str, Any] | None) -> dict[str, Any] | None:
async def get(self, query: str, variables: dict[str, Any] | None) -> dict[str, Any] | None:
"""Return cached result if present and not expired, else None."""
key = self._cache_key(query, variables)
entry = self._store.get(key)
if entry is None:
return None
expires_at, data = entry
if time.monotonic() > expires_at:
del self._store[key]
return None
return data
async with self._lock:
key = self._cache_key(query, variables)
entry = self._store.get(key)
if entry is None:
return None
expires_at, data = entry
if time.monotonic() > expires_at:
del self._store[key]
return None
return data
def put(self, query: str, variables: dict[str, Any] | None, data: dict[str, Any]) -> None:
"""Store a query result with TTL expiry."""
key = self._cache_key(query, variables)
self._store[key] = (time.monotonic() + _CACHE_TTL_SECONDS, data)
async def put(self, query: str, variables: dict[str, Any] | None, data: dict[str, Any]) -> None:
"""Store a query result with TTL expiry, evicting oldest entry if at capacity."""
async with self._lock:
if len(self._store) >= self._MAX_ENTRIES:
oldest_key = min(self._store, key=lambda k: self._store[k][0])
del self._store[oldest_key]
key = self._cache_key(query, variables)
self._store[key] = (time.monotonic() + _CACHE_TTL_SECONDS, data)
def invalidate_all(self) -> None:
async def invalidate_all(self) -> None:
"""Clear the entire cache (called after mutations)."""
self._store.clear()
async with self._lock:
self._store.clear()
_query_cache = _QueryCache()
@@ -310,10 +320,10 @@ async def make_graphql_request(
if not UNRAID_API_KEY:
raise ToolError("UNRAID_API_KEY not configured")
# Check TTL cache for stable read-only queries
# Check TTL cache — short-circuits rate limiter on hits
is_mutation = query.lstrip().startswith("mutation")
if not is_mutation and _query_cache.is_cacheable(query):
cached = _query_cache.get(query, variables)
cached = await _query_cache.get(query, variables)
if cached is not None:
logger.debug("Returning cached response for query")
return cached
@@ -399,9 +409,9 @@ async def make_graphql_request(
# Invalidate cache on mutations; cache eligible query results
if is_mutation:
_query_cache.invalidate_all()
await _query_cache.invalidate_all()
elif _query_cache.is_cacheable(query):
_query_cache.put(query, variables, result)
await _query_cache.put(query, variables, result)
return result

View File

@@ -11,12 +11,19 @@ import sys
async def shutdown_cleanup() -> None:
"""Cleanup resources on server shutdown."""
try:
from .subscriptions.manager import subscription_manager
await subscription_manager.stop_all()
except Exception as e:
print(f"Error stopping subscriptions during cleanup: {e}", file=sys.stderr)
try:
from .core.client import close_http_client
await close_http_client()
except Exception as e:
print(f"Error during cleanup: {e}")
print(f"Error during cleanup: {e}", file=sys.stderr)
def _run_shutdown_cleanup() -> None:

View File

@@ -10,8 +10,6 @@ from fastmcp import FastMCP
from .config.logging import logger
from .config.settings import (
UNRAID_API_KEY,
UNRAID_API_URL,
UNRAID_MCP_HOST,
UNRAID_MCP_PORT,
UNRAID_MCP_TRANSPORT,
@@ -86,20 +84,10 @@ def run_server() -> None:
)
sys.exit(1)
# Log configuration
if UNRAID_API_URL:
logger.info(f"UNRAID_API_URL loaded: {UNRAID_API_URL[:20]}...")
else:
logger.warning("UNRAID_API_URL not found in environment or .env file.")
# Log configuration (delegated to shared function)
from .config.logging import log_configuration_status
if UNRAID_API_KEY:
logger.info("UNRAID_API_KEY loaded: ****")
else:
logger.warning("UNRAID_API_KEY not found in environment or .env file.")
logger.info(f"UNRAID_MCP_PORT set to: {UNRAID_MCP_PORT}")
logger.info(f"UNRAID_MCP_HOST set to: {UNRAID_MCP_HOST}")
logger.info(f"UNRAID_MCP_TRANSPORT set to: {UNRAID_MCP_TRANSPORT}")
log_configuration_status(logger)
if UNRAID_VERIFY_SSL is False:
logger.warning(

View File

@@ -22,7 +22,7 @@ from ..core.exceptions import ToolError
from ..core.utils import safe_display_url
from .manager import subscription_manager
from .resources import ensure_subscriptions_started
from .utils import build_ws_ssl_context, build_ws_url
from .utils import _analyze_subscription_status, build_ws_ssl_context, build_ws_url
_ALLOWED_SUBSCRIPTION_NAMES = frozenset(
@@ -187,8 +187,10 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
# Get comprehensive status
status = await subscription_manager.get_subscription_status()
# Initialize connection issues list with proper type
connection_issues: list[dict[str, Any]] = []
# Analyze connection issues and error counts via the shared helper.
# This ensures "invalid_uri" and all other error states are counted
# consistently with the health tool's _diagnose_subscriptions path.
error_count, connection_issues = _analyze_subscription_status(status)
# Add environment info with explicit typing
diagnostic_info: dict[str, Any] = {
@@ -210,7 +212,7 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
),
"active_count": len(subscription_manager.active_subscriptions),
"with_data": len(subscription_manager.resource_data),
"in_error_state": 0,
"in_error_state": error_count,
"connection_issues": connection_issues,
},
}
@@ -219,23 +221,6 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
with contextlib.suppress(ValueError):
diagnostic_info["environment"]["websocket_url"] = build_ws_url()
# Analyze issues
for sub_name, sub_status in status.items():
runtime = sub_status.get("runtime", {})
connection_state = runtime.get("connection_state", "unknown")
if connection_state in ["error", "auth_failed", "timeout", "max_retries_exceeded"]:
diagnostic_info["summary"]["in_error_state"] += 1
if runtime.get("last_error"):
connection_issues.append(
{
"subscription": sub_name,
"state": connection_state,
"error": runtime["last_error"],
}
)
# Add troubleshooting recommendations
recommendations: list[str] = []

View File

@@ -45,7 +45,7 @@ def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]:
elif (
key == "content"
and isinstance(value, str)
and len(value.encode("utf-8", errors="replace")) > _MAX_RESOURCE_DATA_BYTES
and len(value) > _MAX_RESOURCE_DATA_BYTES # fast pre-check on char count
):
lines = value.splitlines()
original_line_count = len(lines)
@@ -54,19 +54,15 @@ def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]:
if len(lines) > _MAX_RESOURCE_DATA_LINES:
lines = lines[-_MAX_RESOURCE_DATA_LINES:]
# Enforce byte cap while preserving whole-line boundaries where possible.
truncated = "\n".join(lines)
truncated_bytes = truncated.encode("utf-8", errors="replace")
while len(lines) > 1 and len(truncated_bytes) > _MAX_RESOURCE_DATA_BYTES:
lines = lines[1:]
truncated = "\n".join(lines)
truncated_bytes = truncated.encode("utf-8", errors="replace")
# Last resort: if a single line still exceeds cap, hard-cap bytes.
if len(truncated_bytes) > _MAX_RESOURCE_DATA_BYTES:
truncated = truncated_bytes[-_MAX_RESOURCE_DATA_BYTES :].decode(
"utf-8", errors="ignore"
)
# Encode once and slice bytes instead of O(n²) line-trim loop
encoded = truncated.encode("utf-8", errors="replace")
if len(encoded) > _MAX_RESOURCE_DATA_BYTES:
truncated = encoded[-_MAX_RESOURCE_DATA_BYTES:].decode("utf-8", errors="ignore")
# Strip partial first line that may have been cut mid-character
nl_pos = truncated.find("\n")
if nl_pos != -1:
truncated = truncated[nl_pos + 1 :]
logger.warning(
f"[RESOURCE] Capped log content from {original_line_count} to "
@@ -202,6 +198,16 @@ class SubscriptionManager:
else:
logger.warning(f"[SUBSCRIPTION:{subscription_name}] No active subscription to stop")
async def stop_all(self) -> None:
"""Stop all active subscriptions (called during server shutdown)."""
subscription_names = list(self.active_subscriptions.keys())
for name in subscription_names:
try:
await self.stop_subscription(name)
except Exception as e:
logger.error(f"[SHUTDOWN] Error stopping subscription '{name}': {e}", exc_info=True)
logger.info(f"[SHUTDOWN] Stopped {len(subscription_names)} subscription(s)")
async def _subscription_loop(
self, subscription_name: str, query: str, variables: dict[str, Any] | None
) -> None:
@@ -512,9 +518,11 @@ class SubscriptionManager:
f"({connected_duration:.0f}s < {_STABLE_CONNECTION_SECONDS}s), "
f"keeping retry counter at {self.reconnect_attempts.get(subscription_name, 0)}"
)
# Calculate backoff delay
retry_delay = min(retry_delay * 1.5, max_retry_delay)
# Only escalate backoff when connection was NOT stable
retry_delay = min(retry_delay * 1.5, max_retry_delay)
else:
# No connection was established — escalate backoff
retry_delay = min(retry_delay * 1.5, max_retry_delay)
logger.info(
f"[WEBSOCKET:{subscription_name}] Reconnecting in {retry_delay:.1f} seconds..."
)

View File

@@ -4,8 +4,10 @@ This module defines MCP resources that bridge between the subscription manager
and the MCP protocol, providing fallback queries when subscription data is unavailable.
"""
import asyncio
import json
import os
from typing import Final
import anyio
from fastmcp import FastMCP
@@ -16,22 +18,29 @@ from .manager import subscription_manager
# Global flag to track subscription startup
_subscriptions_started = False
_startup_lock: Final[asyncio.Lock] = asyncio.Lock()
async def ensure_subscriptions_started() -> None:
"""Ensure subscriptions are started, called from async context."""
global _subscriptions_started
# Fast-path: skip lock if already started
if _subscriptions_started:
return
logger.info("[STARTUP] First async operation detected, starting subscriptions...")
try:
await autostart_subscriptions()
_subscriptions_started = True
logger.info("[STARTUP] Subscriptions started successfully")
except Exception as e:
logger.error(f"[STARTUP] Failed to start subscriptions: {e}", exc_info=True)
# Slow-path: acquire lock for initialization (double-checked locking)
async with _startup_lock:
if _subscriptions_started:
return
logger.info("[STARTUP] First async operation detected, starting subscriptions...")
try:
await autostart_subscriptions()
_subscriptions_started = True
logger.info("[STARTUP] Subscriptions started successfully")
except Exception as e:
logger.error(f"[STARTUP] Failed to start subscriptions: {e}", exc_info=True)
async def autostart_subscriptions() -> None:

View File

@@ -1,6 +1,7 @@
"""Shared utilities for the subscription system."""
import ssl as _ssl
from typing import Any
from ..config.settings import UNRAID_API_URL, UNRAID_VERIFY_SSL
@@ -52,3 +53,37 @@ def build_ws_ssl_context(ws_url: str) -> _ssl.SSLContext | None:
ctx.check_hostname = False
ctx.verify_mode = _ssl.CERT_NONE
return ctx
def _analyze_subscription_status(
status: dict[str, Any],
) -> tuple[int, list[dict[str, Any]]]:
"""Analyze subscription status dict, returning error count and connection issues.
This is the canonical, shared implementation used by both the health tool
and the subscription diagnostics tool.
Args:
status: Dict of subscription name -> status info from get_subscription_status().
Returns:
Tuple of (error_count, connection_issues_list).
"""
error_count = 0
connection_issues: list[dict[str, Any]] = []
for sub_name, sub_status in status.items():
runtime = sub_status.get("runtime", {})
conn_state = runtime.get("connection_state", "unknown")
if conn_state in ("error", "auth_failed", "timeout", "max_retries_exceeded", "invalid_uri"):
error_count += 1
if runtime.get("last_error"):
connection_issues.append(
{
"subscription": sub_name,
"state": conn_state,
"error": runtime["last_error"],
}
)
return error_count, connection_issues

View File

@@ -1,14 +1,14 @@
"""MCP tools organized by functional domain.
10 consolidated tools with ~90 actions total:
10 consolidated tools with 76 actions total:
unraid_info - System information queries (19 actions)
unraid_array - Array operations and power management (12 actions)
unraid_array - Array operations and parity management (5 actions)
unraid_storage - Storage, disks, and logs (6 actions)
unraid_docker - Docker container management (15 actions)
unraid_vm - Virtual machine management (9 actions)
unraid_notifications - Notification management (9 actions)
unraid_rclone - Cloud storage remotes (4 actions)
unraid_users - User management (8 actions)
unraid_users - User management (1 action)
unraid_keys - API key management (5 actions)
unraid_health - Health monitoring and diagnostics (3 actions)
"""

View File

@@ -73,7 +73,7 @@ def register_array_tool(mcp: FastMCP) -> None:
"""Manage Unraid array parity checks.
Actions:
parity_start - Start parity check (optional correct=True to fix errors)
parity_start - Start parity check (correct=True to fix errors, correct=False for read-only; required)
parity_pause - Pause running parity check
parity_resume - Resume paused parity check
parity_cancel - Cancel running parity check

View File

@@ -233,8 +233,8 @@ async def _resolve_container_id(container_id: str, *, strict: bool = False) -> s
data = await make_graphql_request(list_query)
containers = safe_get(data, "docker", "containers", default=[])
# Short hex prefix: match by ID prefix before trying name matching
if _DOCKER_SHORT_ID_PATTERN.match(container_id):
# Short hex prefix: match by ID prefix before trying name matching (strict bypasses this)
if not strict and _DOCKER_SHORT_ID_PATTERN.match(container_id):
id_lower = container_id.lower()
matches: list[dict[str, Any]] = []
for c in containers:

View File

@@ -21,6 +21,7 @@ from ..config.settings import (
from ..core.client import make_graphql_request
from ..core.exceptions import ToolError, tool_error_handler
from ..core.utils import safe_display_url
from ..subscriptions.utils import _analyze_subscription_status
ALL_ACTIONS = {"check", "test_connection", "diagnose"}
@@ -218,42 +219,6 @@ async def _comprehensive_check() -> dict[str, Any]:
}
def _analyze_subscription_status(
status: dict[str, Any],
) -> tuple[int, list[dict[str, Any]]]:
"""Analyze subscription status dict, returning error count and connection issues.
This is the canonical implementation of subscription status analysis.
TODO: subscriptions/diagnostics.py has a similar status-analysis pattern
in diagnose_subscriptions(). That module could import and call this helper
directly to avoid divergence. See Code-H05.
Args:
status: Dict of subscription name -> status info from get_subscription_status().
Returns:
Tuple of (error_count, connection_issues_list).
"""
error_count = 0
connection_issues: list[dict[str, Any]] = []
for sub_name, sub_status in status.items():
runtime = sub_status.get("runtime", {})
conn_state = runtime.get("connection_state", "unknown")
if conn_state in ("error", "auth_failed", "timeout", "max_retries_exceeded"):
error_count += 1
if runtime.get("last_error"):
connection_issues.append(
{
"subscription": sub_name,
"state": conn_state,
"error": runtime["last_error"],
}
)
return error_count, connection_issues
async def _diagnose_subscriptions() -> dict[str, Any]:
"""Import and run subscription diagnostics."""
try:

View File

@@ -114,10 +114,14 @@ def register_keys_tool(mcp: FastMCP) -> None:
if permissions is not None:
input_data["permissions"] = permissions
data = await make_graphql_request(MUTATIONS["create"], {"input": input_data})
return {
"success": True,
"key": (data.get("apiKey") or {}).get("create", {}),
}
created_key = (data.get("apiKey") or {}).get("create")
if not created_key:
return {
"success": False,
"key": {},
"message": "API key creation failed: no data returned from server",
}
return {"success": True, "key": created_key}
if action == "update":
if not key_id:
@@ -128,10 +132,14 @@ def register_keys_tool(mcp: FastMCP) -> None:
if roles is not None:
input_data["roles"] = roles
data = await make_graphql_request(MUTATIONS["update"], {"input": input_data})
return {
"success": True,
"key": (data.get("apiKey") or {}).get("update", {}),
}
updated_key = (data.get("apiKey") or {}).get("update")
if not updated_key:
return {
"success": False,
"key": {},
"message": "API key update failed: no data returned from server",
}
return {"success": True, "key": updated_key}
if action == "delete":
if not key_id:

View File

@@ -50,34 +50,80 @@ MUTATIONS: dict[str, str] = {
""",
"archive": """
mutation ArchiveNotification($id: PrefixedID!) {
archiveNotification(id: $id)
archiveNotification(id: $id) { id title importance }
}
""",
"unread": """
mutation UnreadNotification($id: PrefixedID!) {
unreadNotification(id: $id)
unreadNotification(id: $id) { id title importance }
}
""",
"delete": """
mutation DeleteNotification($id: PrefixedID!, $type: NotificationType!) {
deleteNotification(id: $id, type: $type)
deleteNotification(id: $id, type: $type) {
unread { info warning alert total }
archive { info warning alert total }
}
}
""",
"delete_archived": """
mutation DeleteArchivedNotifications {
deleteArchivedNotifications
deleteArchivedNotifications {
unread { info warning alert total }
archive { info warning alert total }
}
}
""",
"archive_all": """
mutation ArchiveAllNotifications($importance: NotificationImportance) {
archiveAll(importance: $importance)
archiveAll(importance: $importance) {
unread { info warning alert total }
archive { info warning alert total }
}
}
""",
"archive_many": """
mutation ArchiveNotifications($ids: [PrefixedID!]!) {
archiveNotifications(ids: $ids) {
unread { info warning alert total }
archive { info warning alert total }
}
}
""",
"create_unique": """
mutation NotifyIfUnique($input: NotificationData!) {
notifyIfUnique(input: $input) { id title importance }
}
""",
"unarchive_many": """
mutation UnarchiveNotifications($ids: [PrefixedID!]!) {
unarchiveNotifications(ids: $ids) {
unread { info warning alert total }
archive { info warning alert total }
}
}
""",
"unarchive_all": """
mutation UnarchiveAll($importance: NotificationImportance) {
unarchiveAll(importance: $importance) {
unread { info warning alert total }
archive { info warning alert total }
}
}
""",
"recalculate": """
mutation RecalculateOverview {
recalculateOverview {
unread { info warning alert total }
archive { info warning alert total }
}
}
""",
}
DESTRUCTIVE_ACTIONS = {"delete", "delete_archived"}
ALL_ACTIONS = set(QUERIES) | set(MUTATIONS)
_VALID_IMPORTANCE = {"ALERT", "WARNING", "NORMAL"}
_VALID_IMPORTANCE = {"ALERT", "WARNING", "INFO"}
NOTIFICATION_ACTIONS = Literal[
"overview",
@@ -89,6 +135,11 @@ NOTIFICATION_ACTIONS = Literal[
"delete",
"delete_archived",
"archive_all",
"archive_many",
"create_unique",
"unarchive_many",
"unarchive_all",
"recalculate",
]
if set(get_args(NOTIFICATION_ACTIONS)) != ALL_ACTIONS:
@@ -108,6 +159,7 @@ def register_notifications_tool(mcp: FastMCP) -> None:
action: NOTIFICATION_ACTIONS,
confirm: bool = False,
notification_id: str | None = None,
notification_ids: list[str] | None = None,
notification_type: str | None = None,
importance: str | None = None,
offset: int = 0,
@@ -129,6 +181,11 @@ def register_notifications_tool(mcp: FastMCP) -> None:
delete - Delete a notification (requires notification_id, notification_type, confirm=True)
delete_archived - Delete all archived notifications (requires confirm=True)
archive_all - Archive all notifications (optional importance filter)
archive_many - Archive multiple notifications by ID (requires notification_ids)
create_unique - Create notification only if no equivalent unread exists (requires title, subject, description, importance)
unarchive_many - Move notifications back to unread (requires notification_ids)
unarchive_all - Move all archived notifications to unread (optional importance filter)
recalculate - Recompute overview counts from disk
"""
if action not in ALL_ACTIONS:
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(ALL_ACTIONS)}")
@@ -212,6 +269,55 @@ def register_notifications_tool(mcp: FastMCP) -> None:
data = await make_graphql_request(MUTATIONS["archive_all"], variables)
return {"success": True, "action": "archive_all", "data": data}
if action == "archive_many":
if not notification_ids:
raise ToolError("notification_ids is required for 'archive_many' action")
data = await make_graphql_request(
MUTATIONS["archive_many"], {"ids": notification_ids}
)
return {"success": True, "action": "archive_many", "data": data}
if action == "create_unique":
if title is None or subject is None or description is None or importance is None:
raise ToolError(
"create_unique requires title, subject, description, and importance"
)
if importance.upper() not in _VALID_IMPORTANCE:
raise ToolError(
f"importance must be one of: {', '.join(sorted(_VALID_IMPORTANCE))}. "
f"Got: '{importance}'"
)
input_data = {
"title": title,
"subject": subject,
"description": description,
"importance": importance.upper(),
}
data = await make_graphql_request(MUTATIONS["create_unique"], {"input": input_data})
notification = data.get("notifyIfUnique")
if notification is None:
return {"success": True, "duplicate": True, "data": None}
return {"success": True, "duplicate": False, "data": notification}
if action == "unarchive_many":
if not notification_ids:
raise ToolError("notification_ids is required for 'unarchive_many' action")
data = await make_graphql_request(
MUTATIONS["unarchive_many"], {"ids": notification_ids}
)
return {"success": True, "action": "unarchive_many", "data": data}
if action == "unarchive_all":
vars_: dict[str, Any] | None = None
if importance:
vars_ = {"importance": importance.upper()}
data = await make_graphql_request(MUTATIONS["unarchive_all"], vars_)
return {"success": True, "action": "unarchive_all", "data": data}
if action == "recalculate":
data = await make_graphql_request(MUTATIONS["recalculate"])
return {"success": True, "action": "recalculate", "data": data}
raise ToolError(f"Unhandled action '{action}' — this is a bug")
logger.info("Notifications tool registered successfully")

View File

@@ -114,56 +114,42 @@ def register_vm_tool(mcp: FastMCP) -> None:
raise ToolError(f"Action '{action}' is destructive. Set confirm=True to proceed.")
with tool_error_handler("vm", action, logger):
try:
logger.info(f"Executing unraid_vm action={action}")
logger.info(f"Executing unraid_vm action={action}")
if action == "list":
data = await make_graphql_request(QUERIES["list"])
if data.get("vms"):
vms = data["vms"].get("domains") or data["vms"].get("domain") or []
if isinstance(vms, dict):
vms = [vms]
return {"vms": vms}
return {"vms": []}
if action == "details":
data = await make_graphql_request(QUERIES["details"])
if not data.get("vms"):
raise ToolError("No VM data returned from server")
if action == "list":
data = await make_graphql_request(QUERIES["list"])
if data.get("vms"):
vms = data["vms"].get("domains") or data["vms"].get("domain") or []
if isinstance(vms, dict):
vms = [vms]
for vm in vms:
if (
vm.get("uuid") == vm_id
or vm.get("id") == vm_id
or vm.get("name") == vm_id
):
return dict(vm)
available = [f"{v.get('name')} (UUID: {v.get('uuid')})" for v in vms]
raise ToolError(f"VM '{vm_id}' not found. Available: {', '.join(available)}")
return {"vms": vms}
return {"vms": []}
# Mutations
if action in MUTATIONS:
data = await make_graphql_request(MUTATIONS[action], {"id": vm_id})
field = _MUTATION_FIELDS.get(action, action)
if data.get("vm") and field in data["vm"]:
return {
"success": data["vm"][field],
"action": action,
"vm_id": vm_id,
}
raise ToolError(f"Failed to {action} VM or unexpected response")
if action == "details":
data = await make_graphql_request(QUERIES["details"])
if not data.get("vms"):
raise ToolError("No VM data returned from server")
vms = data["vms"].get("domains") or data["vms"].get("domain") or []
if isinstance(vms, dict):
vms = [vms]
for vm in vms:
if vm.get("uuid") == vm_id or vm.get("id") == vm_id or vm.get("name") == vm_id:
return dict(vm)
available = [f"{v.get('name')} (UUID: {v.get('uuid')})" for v in vms]
raise ToolError(f"VM '{vm_id}' not found. Available: {', '.join(available)}")
raise ToolError(f"Unhandled action '{action}' — this is a bug")
# Mutations
if action in MUTATIONS:
data = await make_graphql_request(MUTATIONS[action], {"id": vm_id})
field = _MUTATION_FIELDS.get(action, action)
if data.get("vm") and field in data["vm"]:
return {
"success": data["vm"][field],
"action": action,
"vm_id": vm_id,
}
raise ToolError(f"Failed to {action} VM or unexpected response")
except ToolError:
raise
except Exception as e:
if "VMs are not available" in str(e):
raise ToolError(
"VMs not available on this server. Check VM support is enabled."
) from e
raise
raise ToolError(f"Unhandled action '{action}' — this is a bug")
logger.info("VM tool registered successfully")