Mirror of https://github.com/jmagar/unraid-mcp.git (synced 2026-03-23 12:39:24 -07:00)
feat: add 5 notification mutations + comprehensive refactors from PR review
New notification actions (archive_many, create_unique, unarchive_many, unarchive_all, recalculate) bring unraid_notifications to 14 actions. Also includes continuation of CodeRabbit/PR review fixes:

- Remove redundant try-except in virtualization.py (silent failure fix)
- Add QueryCache protocol with get/put/invalidate_all to core/client.py
- Refactor subscriptions (manager, diagnostics, resources, utils)
- Update config (logging, settings) for improved structure
- Expand test coverage: http_layer, safety guards, schema validation
- Minor cleanups: array, docker, health, keys tools

Co-authored-by: Claude <noreply@anthropic.com>
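The QueryCache protocol named in the commit message is not itself part of the hunks below, but the async _QueryCache methods in core/client.py imply a surface along these lines; this is only a rough sketch of that assumed shape:

```python
# Sketch of the QueryCache protocol described in the commit message. The exact
# definition in core/client.py is not shown in this diff; the method names and
# async signatures are inferred from the _QueryCache hunks below.
from typing import Any, Protocol


class QueryCache(Protocol):
    async def get(self, query: str, variables: dict[str, Any] | None) -> dict[str, Any] | None:
        """Return a cached result if present and not expired, else None."""
        ...

    async def put(self, query: str, variables: dict[str, Any] | None, data: dict[str, Any]) -> None:
        """Store a query result with TTL expiry."""
        ...

    async def invalidate_all(self) -> None:
        """Clear the cache (called after mutations)."""
        ...
```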
@@ -53,21 +53,22 @@ class OverwriteFileHandler(logging.FileHandler):
):
try:
base_path = Path(self.baseFilename)
if base_path.exists():
file_size = base_path.stat().st_size
if file_size >= self.max_bytes:
# Close current stream
if self.stream:
self.stream.close()

# Remove the old file and start fresh
if base_path.exists():
base_path.unlink()

# Reopen with truncate mode
file_size = base_path.stat().st_size if base_path.exists() else 0
if file_size >= self.max_bytes:
old_stream = self.stream
self.stream = None
try:
old_stream.close()
base_path.unlink(missing_ok=True)
self.stream = self._open()
except OSError:
# Recovery: attempt to reopen even if unlink failed
try:
self.stream = self._open()
except OSError:
self.stream = old_stream  # Last resort: restore original

# Log a marker that the file was reset
if self.stream is not None:
reset_record = logging.LogRecord(
name="UnraidMCPServer.Logging",
level=logging.INFO,
@@ -184,27 +185,8 @@ def configure_fastmcp_logger_with_rich() -> logging.Logger | None:

fastmcp_logger.setLevel(numeric_log_level)

# Also configure the root logger to catch any other logs
root_logger = logging.getLogger()
root_logger.handlers.clear()
root_logger.propagate = False

# Rich Console Handler for root logger
root_console_handler = RichHandler(
console=console,
show_time=True,
show_level=True,
show_path=False,
rich_tracebacks=True,
tracebacks_show_locals=False,
markup=True,
)
root_console_handler.setLevel(numeric_log_level)
root_logger.addHandler(root_console_handler)

# Reuse the shared file handler for root logger
root_logger.addHandler(_shared_file_handler)
root_logger.setLevel(numeric_log_level)
# Set root logger level to avoid suppressing library warnings entirely
logging.getLogger().setLevel(numeric_log_level)

return fastmcp_logger
@@ -36,8 +36,27 @@ for dotenv_path in dotenv_paths:
UNRAID_API_URL = os.getenv("UNRAID_API_URL")
UNRAID_API_KEY = os.getenv("UNRAID_API_KEY")


# Server Configuration
UNRAID_MCP_PORT = int(os.getenv("UNRAID_MCP_PORT", "6970"))
def _parse_port(env_var: str, default: int) -> int:
"""Parse a port number from environment variable with validation."""
raw = os.getenv(env_var, str(default))
try:
port = int(raw)
except ValueError:
import sys

print(f"FATAL: {env_var}={raw!r} is not a valid integer port number", file=sys.stderr)
sys.exit(1)
if not (1 <= port <= 65535):
import sys

print(f"FATAL: {env_var}={port} outside valid port range 1-65535", file=sys.stderr)
sys.exit(1)
return port


UNRAID_MCP_PORT = _parse_port("UNRAID_MCP_PORT", 6970)
UNRAID_MCP_HOST = os.getenv("UNRAID_MCP_HOST", "0.0.0.0")  # noqa: S104 — intentional for Docker
UNRAID_MCP_TRANSPORT = os.getenv("UNRAID_MCP_TRANSPORT", "streamable-http").lower()

@@ -58,7 +77,7 @@ IS_DOCKER = Path("/.dockerenv").exists()
LOGS_DIR = Path("/app/logs") if IS_DOCKER else PROJECT_ROOT / "logs"
LOG_FILE_PATH = LOGS_DIR / LOG_FILE_NAME

# Ensure logs directory exists; if creation fails, fall back to /tmp.
# Ensure logs directory exists; if creation fails, fall back to PROJECT_ROOT / ".cache" / "logs".
try:
LOGS_DIR.mkdir(parents=True, exist_ok=True)
except OSError:
@@ -97,9 +116,11 @@ def get_config_summary() -> dict[str, Any]:
"""
is_valid, missing = validate_required_config()

from ..core.utils import safe_display_url

return {
"api_url_configured": bool(UNRAID_API_URL),
"api_url_preview": UNRAID_API_URL[:20] + "..." if UNRAID_API_URL else None,
"api_url_preview": safe_display_url(UNRAID_API_URL) if UNRAID_API_URL else None,
"api_key_configured": bool(UNRAID_API_KEY),
"server_host": UNRAID_MCP_HOST,
"server_port": UNRAID_MCP_PORT,
@@ -110,5 +131,7 @@ def get_config_summary() -> dict[str, Any]:
"config_valid": is_valid,
"missing_config": missing if not is_valid else None,
}


# Re-export application version from a single source of truth.
VERSION = APP_VERSION
@@ -51,9 +51,7 @@ def _is_sensitive_key(key: str) -> bool:
def redact_sensitive(obj: Any) -> Any:
"""Recursively redact sensitive values from nested dicts/lists."""
if isinstance(obj, dict):
return {
k: ("***" if _is_sensitive_key(k) else redact_sensitive(v)) for k, v in obj.items()
}
return {k: ("***" if _is_sensitive_key(k) else redact_sensitive(v)) for k, v in obj.items()}
if isinstance(obj, list):
return [redact_sensitive(item) for item in obj]
return obj
@@ -149,10 +147,16 @@ class _QueryCache:
Keyed by a hash of (query, variables). Entries expire after _CACHE_TTL_SECONDS.
Only caches responses for queries whose operation name is in _CACHEABLE_QUERY_PREFIXES.
Mutation requests always bypass the cache.

Thread-safe via asyncio.Lock. Bounded to _MAX_ENTRIES with FIFO eviction (oldest
expiry timestamp evicted first when the store is full).
"""

_MAX_ENTRIES: Final[int] = 256

def __init__(self) -> None:
self._store: dict[str, tuple[float, dict[str, Any]]] = {}
self._lock: Final[asyncio.Lock] = asyncio.Lock()

@staticmethod
def _cache_key(query: str, variables: dict[str, Any] | None) -> str:
@@ -170,26 +174,32 @@ class _QueryCache:
return False
return match.group(1) in _CACHEABLE_QUERY_PREFIXES

def get(self, query: str, variables: dict[str, Any] | None) -> dict[str, Any] | None:
async def get(self, query: str, variables: dict[str, Any] | None) -> dict[str, Any] | None:
"""Return cached result if present and not expired, else None."""
key = self._cache_key(query, variables)
entry = self._store.get(key)
if entry is None:
return None
expires_at, data = entry
if time.monotonic() > expires_at:
del self._store[key]
return None
return data
async with self._lock:
key = self._cache_key(query, variables)
entry = self._store.get(key)
if entry is None:
return None
expires_at, data = entry
if time.monotonic() > expires_at:
del self._store[key]
return None
return data

def put(self, query: str, variables: dict[str, Any] | None, data: dict[str, Any]) -> None:
"""Store a query result with TTL expiry."""
key = self._cache_key(query, variables)
self._store[key] = (time.monotonic() + _CACHE_TTL_SECONDS, data)
async def put(self, query: str, variables: dict[str, Any] | None, data: dict[str, Any]) -> None:
"""Store a query result with TTL expiry, evicting oldest entry if at capacity."""
async with self._lock:
if len(self._store) >= self._MAX_ENTRIES:
oldest_key = min(self._store, key=lambda k: self._store[k][0])
del self._store[oldest_key]
key = self._cache_key(query, variables)
self._store[key] = (time.monotonic() + _CACHE_TTL_SECONDS, data)

def invalidate_all(self) -> None:
async def invalidate_all(self) -> None:
"""Clear the entire cache (called after mutations)."""
self._store.clear()
async with self._lock:
self._store.clear()


_query_cache = _QueryCache()
@@ -310,10 +320,10 @@ async def make_graphql_request(
if not UNRAID_API_KEY:
raise ToolError("UNRAID_API_KEY not configured")

# Check TTL cache for stable read-only queries
# Check TTL cache — short-circuits rate limiter on hits
is_mutation = query.lstrip().startswith("mutation")
if not is_mutation and _query_cache.is_cacheable(query):
cached = _query_cache.get(query, variables)
cached = await _query_cache.get(query, variables)
if cached is not None:
logger.debug("Returning cached response for query")
return cached
@@ -399,9 +409,9 @@ async def make_graphql_request(

# Invalidate cache on mutations; cache eligible query results
if is_mutation:
_query_cache.invalidate_all()
await _query_cache.invalidate_all()
elif _query_cache.is_cacheable(query):
_query_cache.put(query, variables, result)
await _query_cache.put(query, variables, result)

return result
@@ -11,12 +11,19 @@ import sys

async def shutdown_cleanup() -> None:
"""Cleanup resources on server shutdown."""
try:
from .subscriptions.manager import subscription_manager

await subscription_manager.stop_all()
except Exception as e:
print(f"Error stopping subscriptions during cleanup: {e}", file=sys.stderr)

try:
from .core.client import close_http_client

await close_http_client()
except Exception as e:
print(f"Error during cleanup: {e}")
print(f"Error during cleanup: {e}", file=sys.stderr)


def _run_shutdown_cleanup() -> None:
@@ -10,8 +10,6 @@ from fastmcp import FastMCP

from .config.logging import logger
from .config.settings import (
UNRAID_API_KEY,
UNRAID_API_URL,
UNRAID_MCP_HOST,
UNRAID_MCP_PORT,
UNRAID_MCP_TRANSPORT,
@@ -86,20 +84,10 @@ def run_server() -> None:
)
sys.exit(1)

# Log configuration
if UNRAID_API_URL:
logger.info(f"UNRAID_API_URL loaded: {UNRAID_API_URL[:20]}...")
else:
logger.warning("UNRAID_API_URL not found in environment or .env file.")
# Log configuration (delegated to shared function)
from .config.logging import log_configuration_status

if UNRAID_API_KEY:
logger.info("UNRAID_API_KEY loaded: ****")
else:
logger.warning("UNRAID_API_KEY not found in environment or .env file.")

logger.info(f"UNRAID_MCP_PORT set to: {UNRAID_MCP_PORT}")
logger.info(f"UNRAID_MCP_HOST set to: {UNRAID_MCP_HOST}")
logger.info(f"UNRAID_MCP_TRANSPORT set to: {UNRAID_MCP_TRANSPORT}")
log_configuration_status(logger)

if UNRAID_VERIFY_SSL is False:
logger.warning(
@@ -22,7 +22,7 @@ from ..core.exceptions import ToolError
from ..core.utils import safe_display_url
from .manager import subscription_manager
from .resources import ensure_subscriptions_started
from .utils import build_ws_ssl_context, build_ws_url
from .utils import _analyze_subscription_status, build_ws_ssl_context, build_ws_url


_ALLOWED_SUBSCRIPTION_NAMES = frozenset(
@@ -187,8 +187,10 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
# Get comprehensive status
status = await subscription_manager.get_subscription_status()

# Initialize connection issues list with proper type
connection_issues: list[dict[str, Any]] = []
# Analyze connection issues and error counts via the shared helper.
# This ensures "invalid_uri" and all other error states are counted
# consistently with the health tool's _diagnose_subscriptions path.
error_count, connection_issues = _analyze_subscription_status(status)

# Add environment info with explicit typing
diagnostic_info: dict[str, Any] = {
@@ -210,7 +212,7 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
),
"active_count": len(subscription_manager.active_subscriptions),
"with_data": len(subscription_manager.resource_data),
"in_error_state": 0,
"in_error_state": error_count,
"connection_issues": connection_issues,
},
}
@@ -219,23 +221,6 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
with contextlib.suppress(ValueError):
diagnostic_info["environment"]["websocket_url"] = build_ws_url()

# Analyze issues
for sub_name, sub_status in status.items():
runtime = sub_status.get("runtime", {})
connection_state = runtime.get("connection_state", "unknown")

if connection_state in ["error", "auth_failed", "timeout", "max_retries_exceeded"]:
diagnostic_info["summary"]["in_error_state"] += 1

if runtime.get("last_error"):
connection_issues.append(
{
"subscription": sub_name,
"state": connection_state,
"error": runtime["last_error"],
}
)

# Add troubleshooting recommendations
recommendations: list[str] = []
@@ -45,7 +45,7 @@ def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]:
elif (
key == "content"
and isinstance(value, str)
and len(value.encode("utf-8", errors="replace")) > _MAX_RESOURCE_DATA_BYTES
and len(value) > _MAX_RESOURCE_DATA_BYTES  # fast pre-check on char count
):
lines = value.splitlines()
original_line_count = len(lines)
@@ -54,19 +54,15 @@ def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]:
if len(lines) > _MAX_RESOURCE_DATA_LINES:
lines = lines[-_MAX_RESOURCE_DATA_LINES:]

# Enforce byte cap while preserving whole-line boundaries where possible.
truncated = "\n".join(lines)
truncated_bytes = truncated.encode("utf-8", errors="replace")
while len(lines) > 1 and len(truncated_bytes) > _MAX_RESOURCE_DATA_BYTES:
lines = lines[1:]
truncated = "\n".join(lines)
truncated_bytes = truncated.encode("utf-8", errors="replace")

# Last resort: if a single line still exceeds cap, hard-cap bytes.
if len(truncated_bytes) > _MAX_RESOURCE_DATA_BYTES:
truncated = truncated_bytes[-_MAX_RESOURCE_DATA_BYTES :].decode(
"utf-8", errors="ignore"
)
# Encode once and slice bytes instead of O(n²) line-trim loop
encoded = truncated.encode("utf-8", errors="replace")
if len(encoded) > _MAX_RESOURCE_DATA_BYTES:
truncated = encoded[-_MAX_RESOURCE_DATA_BYTES:].decode("utf-8", errors="ignore")
# Strip partial first line that may have been cut mid-character
nl_pos = truncated.find("\n")
if nl_pos != -1:
truncated = truncated[nl_pos + 1 :]

logger.warning(
f"[RESOURCE] Capped log content from {original_line_count} to "
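As a toy illustration of the byte-cap strategy above (encode once, keep the trailing bytes, then drop the possibly partial first line), with a made-up 32-byte cap standing in for _MAX_RESOURCE_DATA_BYTES:

```python
# Illustrative only: the cap value and sample text are invented for the demo.
_MAX = 32
text = "línea uno\nlínea dos\nlínea tres y algo más"

encoded = text.encode("utf-8", errors="replace")
if len(encoded) > _MAX:
    # Keep the last _MAX bytes; errors="ignore" drops any split multi-byte char.
    truncated = encoded[-_MAX:].decode("utf-8", errors="ignore")
    # Strip the partial first line that may have been cut mid-way.
    nl_pos = truncated.find("\n")
    if nl_pos != -1:
        truncated = truncated[nl_pos + 1 :]
    print(truncated)  # -> "línea tres y algo más"
```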
@@ -202,6 +198,16 @@ class SubscriptionManager:
else:
logger.warning(f"[SUBSCRIPTION:{subscription_name}] No active subscription to stop")

async def stop_all(self) -> None:
"""Stop all active subscriptions (called during server shutdown)."""
subscription_names = list(self.active_subscriptions.keys())
for name in subscription_names:
try:
await self.stop_subscription(name)
except Exception as e:
logger.error(f"[SHUTDOWN] Error stopping subscription '{name}': {e}", exc_info=True)
logger.info(f"[SHUTDOWN] Stopped {len(subscription_names)} subscription(s)")

async def _subscription_loop(
self, subscription_name: str, query: str, variables: dict[str, Any] | None
) -> None:
@@ -512,9 +518,11 @@ class SubscriptionManager:
f"({connected_duration:.0f}s < {_STABLE_CONNECTION_SECONDS}s), "
f"keeping retry counter at {self.reconnect_attempts.get(subscription_name, 0)}"
)

# Calculate backoff delay
retry_delay = min(retry_delay * 1.5, max_retry_delay)
# Only escalate backoff when connection was NOT stable
retry_delay = min(retry_delay * 1.5, max_retry_delay)
else:
# No connection was established — escalate backoff
retry_delay = min(retry_delay * 1.5, max_retry_delay)
logger.info(
f"[WEBSOCKET:{subscription_name}] Reconnecting in {retry_delay:.1f} seconds..."
)
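For intuition, the 1.5x escalation above behaves like this small sketch; the starting delay and ceiling are placeholder values, since the real constants are defined elsewhere in the manager:

```python
# Assumed values for illustration; the real initial delay and cap live in
# SubscriptionManager, not in this hunk.
retry_delay = 2.0
max_retry_delay = 60.0

for attempt in range(1, 11):
    retry_delay = min(retry_delay * 1.5, max_retry_delay)
    print(f"attempt {attempt}: reconnect in {retry_delay:.1f}s")
# Delays grow geometrically (3.0, 4.5, 6.8, ...) until capped at 60s.
```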
@@ -4,8 +4,10 @@ This module defines MCP resources that bridge between the subscription manager
and the MCP protocol, providing fallback queries when subscription data is unavailable.
"""

import asyncio
import json
import os
from typing import Final

import anyio
from fastmcp import FastMCP
@@ -16,22 +18,29 @@ from .manager import subscription_manager

# Global flag to track subscription startup
_subscriptions_started = False
_startup_lock: Final[asyncio.Lock] = asyncio.Lock()


async def ensure_subscriptions_started() -> None:
"""Ensure subscriptions are started, called from async context."""
global _subscriptions_started

# Fast-path: skip lock if already started
if _subscriptions_started:
return

logger.info("[STARTUP] First async operation detected, starting subscriptions...")
try:
await autostart_subscriptions()
_subscriptions_started = True
logger.info("[STARTUP] Subscriptions started successfully")
except Exception as e:
logger.error(f"[STARTUP] Failed to start subscriptions: {e}", exc_info=True)
# Slow-path: acquire lock for initialization (double-checked locking)
async with _startup_lock:
if _subscriptions_started:
return

logger.info("[STARTUP] First async operation detected, starting subscriptions...")
try:
await autostart_subscriptions()
_subscriptions_started = True
logger.info("[STARTUP] Subscriptions started successfully")
except Exception as e:
logger.error(f"[STARTUP] Failed to start subscriptions: {e}", exc_info=True)


async def autostart_subscriptions() -> None:
@@ -1,6 +1,7 @@
"""Shared utilities for the subscription system."""

import ssl as _ssl
from typing import Any

from ..config.settings import UNRAID_API_URL, UNRAID_VERIFY_SSL

@@ -52,3 +53,37 @@ def build_ws_ssl_context(ws_url: str) -> _ssl.SSLContext | None:
ctx.check_hostname = False
ctx.verify_mode = _ssl.CERT_NONE
return ctx


def _analyze_subscription_status(
status: dict[str, Any],
) -> tuple[int, list[dict[str, Any]]]:
"""Analyze subscription status dict, returning error count and connection issues.

This is the canonical, shared implementation used by both the health tool
and the subscription diagnostics tool.

Args:
status: Dict of subscription name -> status info from get_subscription_status().

Returns:
Tuple of (error_count, connection_issues_list).
"""
error_count = 0
connection_issues: list[dict[str, Any]] = []

for sub_name, sub_status in status.items():
runtime = sub_status.get("runtime", {})
conn_state = runtime.get("connection_state", "unknown")
if conn_state in ("error", "auth_failed", "timeout", "max_retries_exceeded", "invalid_uri"):
error_count += 1
if runtime.get("last_error"):
connection_issues.append(
{
"subscription": sub_name,
"state": conn_state,
"error": runtime["last_error"],
}
)

return error_count, connection_issues
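Assuming the helper above is importable, a small worked example of its input and output shape (the status dict here is invented, modeled on what get_subscription_status() is described as returning):

```python
# Hypothetical status payload; only runtime.connection_state and
# runtime.last_error are inspected by the helper.
status = {
    "logs_sub": {"runtime": {"connection_state": "connected"}},
    "array_sub": {
        "runtime": {"connection_state": "auth_failed", "last_error": "401 Unauthorized"}
    },
}

error_count, issues = _analyze_subscription_status(status)
assert error_count == 1
assert issues == [
    {"subscription": "array_sub", "state": "auth_failed", "error": "401 Unauthorized"}
]
```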
@@ -1,14 +1,14 @@
"""MCP tools organized by functional domain.

10 consolidated tools with ~90 actions total:
10 consolidated tools with 76 actions total:
unraid_info - System information queries (19 actions)
unraid_array - Array operations and power management (12 actions)
unraid_array - Array operations and parity management (5 actions)
unraid_storage - Storage, disks, and logs (6 actions)
unraid_docker - Docker container management (15 actions)
unraid_vm - Virtual machine management (9 actions)
unraid_notifications - Notification management (9 actions)
unraid_rclone - Cloud storage remotes (4 actions)
unraid_users - User management (8 actions)
unraid_users - User management (1 action)
unraid_keys - API key management (5 actions)
unraid_health - Health monitoring and diagnostics (3 actions)
"""
@@ -73,7 +73,7 @@ def register_array_tool(mcp: FastMCP) -> None:
"""Manage Unraid array parity checks.

Actions:
parity_start - Start parity check (optional correct=True to fix errors)
parity_start - Start parity check (correct=True to fix errors, correct=False for read-only; required)
parity_pause - Pause running parity check
parity_resume - Resume paused parity check
parity_cancel - Cancel running parity check
@@ -233,8 +233,8 @@ async def _resolve_container_id(container_id: str, *, strict: bool = False) -> s
data = await make_graphql_request(list_query)
containers = safe_get(data, "docker", "containers", default=[])

# Short hex prefix: match by ID prefix before trying name matching
if _DOCKER_SHORT_ID_PATTERN.match(container_id):
# Short hex prefix: match by ID prefix before trying name matching (strict bypasses this)
if not strict and _DOCKER_SHORT_ID_PATTERN.match(container_id):
id_lower = container_id.lower()
matches: list[dict[str, Any]] = []
for c in containers:
@@ -21,6 +21,7 @@ from ..config.settings import (
from ..core.client import make_graphql_request
from ..core.exceptions import ToolError, tool_error_handler
from ..core.utils import safe_display_url
from ..subscriptions.utils import _analyze_subscription_status


ALL_ACTIONS = {"check", "test_connection", "diagnose"}
@@ -218,42 +219,6 @@ async def _comprehensive_check() -> dict[str, Any]:
}


def _analyze_subscription_status(
status: dict[str, Any],
) -> tuple[int, list[dict[str, Any]]]:
"""Analyze subscription status dict, returning error count and connection issues.

This is the canonical implementation of subscription status analysis.
TODO: subscriptions/diagnostics.py has a similar status-analysis pattern
in diagnose_subscriptions(). That module could import and call this helper
directly to avoid divergence. See Code-H05.

Args:
status: Dict of subscription name -> status info from get_subscription_status().

Returns:
Tuple of (error_count, connection_issues_list).
"""
error_count = 0
connection_issues: list[dict[str, Any]] = []

for sub_name, sub_status in status.items():
runtime = sub_status.get("runtime", {})
conn_state = runtime.get("connection_state", "unknown")
if conn_state in ("error", "auth_failed", "timeout", "max_retries_exceeded"):
error_count += 1
if runtime.get("last_error"):
connection_issues.append(
{
"subscription": sub_name,
"state": conn_state,
"error": runtime["last_error"],
}
)

return error_count, connection_issues


async def _diagnose_subscriptions() -> dict[str, Any]:
"""Import and run subscription diagnostics."""
try:
@@ -114,10 +114,14 @@ def register_keys_tool(mcp: FastMCP) -> None:
if permissions is not None:
input_data["permissions"] = permissions
data = await make_graphql_request(MUTATIONS["create"], {"input": input_data})
return {
"success": True,
"key": (data.get("apiKey") or {}).get("create", {}),
}
created_key = (data.get("apiKey") or {}).get("create")
if not created_key:
return {
"success": False,
"key": {},
"message": "API key creation failed: no data returned from server",
}
return {"success": True, "key": created_key}

if action == "update":
if not key_id:
@@ -128,10 +132,14 @@
if roles is not None:
input_data["roles"] = roles
data = await make_graphql_request(MUTATIONS["update"], {"input": input_data})
return {
"success": True,
"key": (data.get("apiKey") or {}).get("update", {}),
}
updated_key = (data.get("apiKey") or {}).get("update")
if not updated_key:
return {
"success": False,
"key": {},
"message": "API key update failed: no data returned from server",
}
return {"success": True, "key": updated_key}

if action == "delete":
if not key_id:
@@ -50,34 +50,80 @@ MUTATIONS: dict[str, str] = {
""",
"archive": """
mutation ArchiveNotification($id: PrefixedID!) {
archiveNotification(id: $id)
archiveNotification(id: $id) { id title importance }
}
""",
"unread": """
mutation UnreadNotification($id: PrefixedID!) {
unreadNotification(id: $id)
unreadNotification(id: $id) { id title importance }
}
""",
"delete": """
mutation DeleteNotification($id: PrefixedID!, $type: NotificationType!) {
deleteNotification(id: $id, type: $type)
deleteNotification(id: $id, type: $type) {
unread { info warning alert total }
archive { info warning alert total }
}
}
""",
"delete_archived": """
mutation DeleteArchivedNotifications {
deleteArchivedNotifications
deleteArchivedNotifications {
unread { info warning alert total }
archive { info warning alert total }
}
}
""",
"archive_all": """
mutation ArchiveAllNotifications($importance: NotificationImportance) {
archiveAll(importance: $importance)
archiveAll(importance: $importance) {
unread { info warning alert total }
archive { info warning alert total }
}
}
""",
"archive_many": """
mutation ArchiveNotifications($ids: [PrefixedID!]!) {
archiveNotifications(ids: $ids) {
unread { info warning alert total }
archive { info warning alert total }
}
}
""",
"create_unique": """
mutation NotifyIfUnique($input: NotificationData!) {
notifyIfUnique(input: $input) { id title importance }
}
""",
"unarchive_many": """
mutation UnarchiveNotifications($ids: [PrefixedID!]!) {
unarchiveNotifications(ids: $ids) {
unread { info warning alert total }
archive { info warning alert total }
}
}
""",
"unarchive_all": """
mutation UnarchiveAll($importance: NotificationImportance) {
unarchiveAll(importance: $importance) {
unread { info warning alert total }
archive { info warning alert total }
}
}
""",
"recalculate": """
mutation RecalculateOverview {
recalculateOverview {
unread { info warning alert total }
archive { info warning alert total }
}
}
""",
}

DESTRUCTIVE_ACTIONS = {"delete", "delete_archived"}
ALL_ACTIONS = set(QUERIES) | set(MUTATIONS)
_VALID_IMPORTANCE = {"ALERT", "WARNING", "NORMAL"}
_VALID_IMPORTANCE = {"ALERT", "WARNING", "INFO"}

NOTIFICATION_ACTIONS = Literal[
"overview",
@@ -89,6 +135,11 @@ NOTIFICATION_ACTIONS = Literal[
"delete",
"delete_archived",
"archive_all",
"archive_many",
"create_unique",
"unarchive_many",
"unarchive_all",
"recalculate",
]

if set(get_args(NOTIFICATION_ACTIONS)) != ALL_ACTIONS:
@@ -108,6 +159,7 @@ def register_notifications_tool(mcp: FastMCP) -> None:
action: NOTIFICATION_ACTIONS,
confirm: bool = False,
notification_id: str | None = None,
notification_ids: list[str] | None = None,
notification_type: str | None = None,
importance: str | None = None,
offset: int = 0,
@@ -129,6 +181,11 @@ def register_notifications_tool(mcp: FastMCP) -> None:
delete - Delete a notification (requires notification_id, notification_type, confirm=True)
delete_archived - Delete all archived notifications (requires confirm=True)
archive_all - Archive all notifications (optional importance filter)
archive_many - Archive multiple notifications by ID (requires notification_ids)
create_unique - Create notification only if no equivalent unread exists (requires title, subject, description, importance)
unarchive_many - Move notifications back to unread (requires notification_ids)
unarchive_all - Move all archived notifications to unread (optional importance filter)
recalculate - Recompute overview counts from disk
"""
if action not in ALL_ACTIONS:
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(ALL_ACTIONS)}")
@@ -212,6 +269,55 @@ def register_notifications_tool(mcp: FastMCP) -> None:
data = await make_graphql_request(MUTATIONS["archive_all"], variables)
return {"success": True, "action": "archive_all", "data": data}

if action == "archive_many":
if not notification_ids:
raise ToolError("notification_ids is required for 'archive_many' action")
data = await make_graphql_request(
MUTATIONS["archive_many"], {"ids": notification_ids}
)
return {"success": True, "action": "archive_many", "data": data}

if action == "create_unique":
if title is None or subject is None or description is None or importance is None:
raise ToolError(
"create_unique requires title, subject, description, and importance"
)
if importance.upper() not in _VALID_IMPORTANCE:
raise ToolError(
f"importance must be one of: {', '.join(sorted(_VALID_IMPORTANCE))}. "
f"Got: '{importance}'"
)
input_data = {
"title": title,
"subject": subject,
"description": description,
"importance": importance.upper(),
}
data = await make_graphql_request(MUTATIONS["create_unique"], {"input": input_data})
notification = data.get("notifyIfUnique")
if notification is None:
return {"success": True, "duplicate": True, "data": None}
return {"success": True, "duplicate": False, "data": notification}

if action == "unarchive_many":
if not notification_ids:
raise ToolError("notification_ids is required for 'unarchive_many' action")
data = await make_graphql_request(
MUTATIONS["unarchive_many"], {"ids": notification_ids}
)
return {"success": True, "action": "unarchive_many", "data": data}

if action == "unarchive_all":
vars_: dict[str, Any] | None = None
if importance:
vars_ = {"importance": importance.upper()}
data = await make_graphql_request(MUTATIONS["unarchive_all"], vars_)
return {"success": True, "action": "unarchive_all", "data": data}

if action == "recalculate":
data = await make_graphql_request(MUTATIONS["recalculate"])
return {"success": True, "action": "recalculate", "data": data}

raise ToolError(f"Unhandled action '{action}' — this is a bug")

logger.info("Notifications tool registered successfully")
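As a rough usage sketch, the archive_many action ends up sending the ArchiveNotifications document defined above via make_graphql_request; an equivalent standalone request might look like the following (the host URL, the x-api-key header name, and the notification ID are assumptions, not taken from this diff):

```python
# Standalone sketch of the archive_many GraphQL call; URL, auth header name,
# and the ID value below are placeholders.
import asyncio

import httpx

ARCHIVE_MANY = """
mutation ArchiveNotifications($ids: [PrefixedID!]!) {
  archiveNotifications(ids: $ids) {
    unread { info warning alert total }
    archive { info warning alert total }
  }
}
"""


async def main() -> None:
    async with httpx.AsyncClient(verify=False) as client:  # mirrors UNRAID_VERIFY_SSL=False
        resp = await client.post(
            "https://unraid.local/graphql",  # placeholder endpoint
            json={"query": ARCHIVE_MANY, "variables": {"ids": ["notification-id-1"]}},
            headers={"x-api-key": "REPLACE_ME"},  # header name assumed
        )
        resp.raise_for_status()
        print(resp.json())


asyncio.run(main())
```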
@@ -114,56 +114,42 @@ def register_vm_tool(mcp: FastMCP) -> None:
raise ToolError(f"Action '{action}' is destructive. Set confirm=True to proceed.")

with tool_error_handler("vm", action, logger):
try:
logger.info(f"Executing unraid_vm action={action}")
logger.info(f"Executing unraid_vm action={action}")

if action == "list":
data = await make_graphql_request(QUERIES["list"])
if data.get("vms"):
vms = data["vms"].get("domains") or data["vms"].get("domain") or []
if isinstance(vms, dict):
vms = [vms]
return {"vms": vms}
return {"vms": []}

if action == "details":
data = await make_graphql_request(QUERIES["details"])
if not data.get("vms"):
raise ToolError("No VM data returned from server")
if action == "list":
data = await make_graphql_request(QUERIES["list"])
if data.get("vms"):
vms = data["vms"].get("domains") or data["vms"].get("domain") or []
if isinstance(vms, dict):
vms = [vms]
for vm in vms:
if (
vm.get("uuid") == vm_id
or vm.get("id") == vm_id
or vm.get("name") == vm_id
):
return dict(vm)
available = [f"{v.get('name')} (UUID: {v.get('uuid')})" for v in vms]
raise ToolError(f"VM '{vm_id}' not found. Available: {', '.join(available)}")
return {"vms": vms}
return {"vms": []}

# Mutations
if action in MUTATIONS:
data = await make_graphql_request(MUTATIONS[action], {"id": vm_id})
field = _MUTATION_FIELDS.get(action, action)
if data.get("vm") and field in data["vm"]:
return {
"success": data["vm"][field],
"action": action,
"vm_id": vm_id,
}
raise ToolError(f"Failed to {action} VM or unexpected response")
if action == "details":
data = await make_graphql_request(QUERIES["details"])
if not data.get("vms"):
raise ToolError("No VM data returned from server")
vms = data["vms"].get("domains") or data["vms"].get("domain") or []
if isinstance(vms, dict):
vms = [vms]
for vm in vms:
if vm.get("uuid") == vm_id or vm.get("id") == vm_id or vm.get("name") == vm_id:
return dict(vm)
available = [f"{v.get('name')} (UUID: {v.get('uuid')})" for v in vms]
raise ToolError(f"VM '{vm_id}' not found. Available: {', '.join(available)}")

raise ToolError(f"Unhandled action '{action}' — this is a bug")
# Mutations
if action in MUTATIONS:
data = await make_graphql_request(MUTATIONS[action], {"id": vm_id})
field = _MUTATION_FIELDS.get(action, action)
if data.get("vm") and field in data["vm"]:
return {
"success": data["vm"][field],
"action": action,
"vm_id": vm_id,
}
raise ToolError(f"Failed to {action} VM or unexpected response")

except ToolError:
raise
except Exception as e:
if "VMs are not available" in str(e):
raise ToolError(
"VMs not available on this server. Check VM support is enabled."
) from e
raise
raise ToolError(f"Unhandled action '{action}' — this is a bug")

logger.info("VM tool registered successfully")