fix: address 54 MEDIUM/LOW priority PR review issues

Comprehensive fixes across Python code, shell scripts, and documentation
addressing all remaining MEDIUM and LOW priority review comments.

Python Code Fixes (27 fixes):
- tools/info.py: Simplified dispatch with lookup tables, defensive guards,
  CPU fallback formatting, !s conversion flags, module-level sync assertion
- tools/docker.py: Case-insensitive container ID regex, keyword-only confirm,
  module-level ALL_ACTIONS constant
- tools/virtualization.py: Normalized single-VM dict responses, unified
  list/details queries
- core/client.py: Fixed HTTP client singleton race condition, compound key
  substring matching for sensitive data redaction
- subscriptions/: Extracted SSL context creation to shared helper in utils.py,
  replaced deprecated ssl._create_unverified_context API
- tools/array.py: Renamed parity_history to parity_status, hoisted ALL_ACTIONS
- tools/storage.py: Fixed dict(None) risks, temperature 0 falsiness bug
- tools/notifications.py, keys.py, rclone.py: Fixed dict(None) TypeError risks
- tests/: Fixed generator type annotations, added coverage for compound keys

Shell Script Fixes (13 fixes):
- dashboard.sh: Dynamic server discovery, conditional debug output, null-safe
  jq, notification count guard order, removed unused variables
- unraid-query.sh: Proper JSON escaping via jq, --ignore-errors and --insecure
  CLI flags, TLS verification now on by default
- validate-marketplace.sh: Removed unused YELLOW variable, defensive jq,
  simplified repository URL output

Documentation Fixes (24+ fixes):
- Version consistency: Updated all references to v0.2.0 across pyproject.toml,
  plugin.json, marketplace.json, MARKETPLACE.md, __init__.py, README files
- Tool count updates: Changed all "26 tools" references to "10 tools, 90 actions"
- Markdown lint: Fixed MD022, MD031, MD047 issues across multiple files
- Research docs: Fixed auth headers, removed web artifacts, corrected stale info
- Skills docs: Fixed query examples, endpoint counts, env var references

All 227 tests pass, ruff and ty checks clean.
This commit is contained in:
Jacob Magar
2026-02-15 17:09:31 -05:00
parent 6bbe46879e
commit 37e9424a5c
58 changed files with 1333 additions and 1175 deletions

View File

@@ -4,4 +4,4 @@ A modular MCP (Model Context Protocol) server that provides tools to interact
with an Unraid server's GraphQL API.
"""
__version__ = "0.1.0"
__version__ = "0.2.0"

View File

@@ -5,8 +5,8 @@ that cap at 10MB and start over (no rotation) for consistent use across all modu
"""
import logging
import os
from datetime import datetime
from pathlib import Path
import pytz
from rich.align import Align
@@ -16,6 +16,7 @@ from rich.panel import Panel
from rich.rule import Rule
from rich.text import Text
try:
from fastmcp.utilities.logging import get_logger as get_fastmcp_logger
FASTMCP_AVAILABLE = True
@@ -24,6 +25,7 @@ except ImportError:
from .settings import LOG_FILE_PATH, LOG_LEVEL_STR
# Global Rich console for consistent formatting
console = Console(stderr=True, force_terminal=True)
@@ -31,7 +33,7 @@ console = Console(stderr=True, force_terminal=True)
class OverwriteFileHandler(logging.FileHandler):
"""Custom file handler that overwrites the log file when it reaches max size."""
def __init__(self, filename, max_bytes=10*1024*1024, mode='a', encoding=None, delay=False):
def __init__(self, filename, max_bytes=10*1024*1024, mode="a", encoding=None, delay=False):
"""Initialize the handler.
Args:
@@ -47,18 +49,19 @@ class OverwriteFileHandler(logging.FileHandler):
def emit(self, record):
"""Emit a record, checking file size and overwriting if needed."""
# Check file size before writing
if self.stream and hasattr(self.stream, 'name'):
if self.stream and hasattr(self.stream, "name"):
try:
if os.path.exists(self.baseFilename):
file_size = os.path.getsize(self.baseFilename)
base_path = Path(self.baseFilename)
if base_path.exists():
file_size = base_path.stat().st_size
if file_size >= self.max_bytes:
# Close current stream
if self.stream:
self.stream.close()
# Remove the old file and start fresh
if os.path.exists(self.baseFilename):
os.remove(self.baseFilename)
if base_path.exists():
base_path.unlink()
# Reopen with truncate mode
self.stream = self._open()
@@ -75,9 +78,10 @@ class OverwriteFileHandler(logging.FileHandler):
)
super().emit(reset_record)
except OSError:
# If there's an issue checking file size, just continue normally
pass
except OSError as e:
import sys
print(f"WARNING: Log file size check failed: {e}. Continuing without rotation.",
file=sys.stderr)
# Emit the original record
super().emit(record)
@@ -119,11 +123,11 @@ def setup_logger(name: str = "UnraidMCPServer") -> logging.Logger:
file_handler = OverwriteFileHandler(
LOG_FILE_PATH,
max_bytes=10*1024*1024,
encoding='utf-8'
encoding="utf-8"
)
file_handler.setLevel(numeric_log_level)
file_formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(module)s - %(funcName)s - %(lineno)d - %(message)s'
"%(asctime)s - %(name)s - %(levelname)s - %(module)s - %(funcName)s - %(lineno)d - %(message)s"
)
file_handler.setFormatter(file_formatter)
logger.addHandler(file_handler)
@@ -163,11 +167,11 @@ def configure_fastmcp_logger_with_rich() -> logging.Logger | None:
file_handler = OverwriteFileHandler(
LOG_FILE_PATH,
max_bytes=10*1024*1024,
encoding='utf-8'
encoding="utf-8"
)
file_handler.setLevel(numeric_log_level)
file_formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(module)s - %(funcName)s - %(lineno)d - %(message)s'
"%(asctime)s - %(name)s - %(levelname)s - %(module)s - %(funcName)s - %(lineno)d - %(message)s"
)
file_handler.setFormatter(file_formatter)
fastmcp_logger.addHandler(file_handler)
@@ -196,7 +200,7 @@ def configure_fastmcp_logger_with_rich() -> logging.Logger | None:
root_file_handler = OverwriteFileHandler(
LOG_FILE_PATH,
max_bytes=10*1024*1024,
encoding='utf-8'
encoding="utf-8"
)
root_file_handler.setLevel(numeric_log_level)
root_file_handler.setFormatter(file_formatter)
@@ -225,12 +229,12 @@ def log_configuration_status(logger: logging.Logger) -> None:
config = get_config_summary()
# Log configuration status
if config['api_url_configured']:
if config["api_url_configured"]:
logger.info(f"UNRAID_API_URL loaded: {config['api_url_preview']}")
else:
logger.warning("UNRAID_API_URL not found in environment or .env file.")
if config['api_key_configured']:
if config["api_key_configured"]:
logger.info("UNRAID_API_KEY loaded: ****") # Don't log the key itself
else:
logger.warning("UNRAID_API_KEY not found in environment or .env file.")
@@ -240,14 +244,14 @@ def log_configuration_status(logger: logging.Logger) -> None:
logger.info(f"UNRAID_MCP_TRANSPORT set to: {config['transport']}")
logger.info(f"UNRAID_MCP_LOG_LEVEL set to: {config['log_level']}")
if not config['config_valid']:
if not config["config_valid"]:
logger.error(f"Missing required configuration: {config['missing_config']}")
# Development logging helpers for Rich formatting
def get_est_timestamp() -> str:
"""Get current timestamp in EST timezone with YY/MM/DD format."""
est = pytz.timezone('US/Eastern')
est = pytz.timezone("US/Eastern")
now = datetime.now(est)
return now.strftime("%y/%m/%d %H:%M:%S")
@@ -271,7 +275,7 @@ def log_with_level_and_indent(message: str, level: str = "info", indent: int = 0
"error": {"color": "#BF616A", "icon": "", "style": "bold"}, # Nordic red
"warning": {"color": "#EBCB8B", "icon": "⚠️", "style": ""}, # Nordic yellow
"success": {"color": "#A3BE8C", "icon": "", "style": "bold"}, # Nordic green
"info": {"color": "#5E81AC", "icon": "", "style": "bold"}, # Nordic blue (bold)
"info": {"color": "#5E81AC", "icon": "\u2139\ufe0f", "style": "bold"}, # Nordic blue (bold)
"status": {"color": "#81A1C1", "icon": "🔍", "style": ""}, # Light Nordic blue
"debug": {"color": "#4C566A", "icon": "🐛", "style": ""}, # Nordic dark gray
}
@@ -328,11 +332,7 @@ def log_status(message: str, indent: int = 0) -> None:
if FASTMCP_AVAILABLE:
# Use FastMCP logger with Rich formatting
_fastmcp_logger = configure_fastmcp_logger_with_rich()
if _fastmcp_logger is not None:
logger = _fastmcp_logger
else:
# Fallback to our custom logger if FastMCP configuration fails
logger = setup_logger()
logger = _fastmcp_logger if _fastmcp_logger is not None else setup_logger()
else:
# Fallback to our custom logger if FastMCP is not available
logger = setup_logger()

View File

@@ -10,6 +10,7 @@ from typing import Any
from dotenv import load_dotenv
# Get the script directory (config module location)
SCRIPT_DIR = Path(__file__).parent # /home/user/code/unraid-mcp/unraid_mcp/config/
UNRAID_MCP_DIR = SCRIPT_DIR.parent # /home/user/code/unraid-mcp/unraid_mcp/
@@ -18,10 +19,10 @@ PROJECT_ROOT = UNRAID_MCP_DIR.parent # /home/user/code/unraid-mcp/
# Load environment variables from .env file
# In container: First try /app/.env.local (mounted), then project root .env
dotenv_paths = [
Path('/app/.env.local'), # Container mount point
PROJECT_ROOT / '.env.local', # Project root .env.local
PROJECT_ROOT / '.env', # Project root .env
UNRAID_MCP_DIR / '.env' # Local .env in unraid_mcp/
Path("/app/.env.local"), # Container mount point
PROJECT_ROOT / ".env.local", # Project root .env.local
PROJECT_ROOT / ".env", # Project root .env
UNRAID_MCP_DIR / ".env" # Local .env in unraid_mcp/
]
for dotenv_path in dotenv_paths:
@@ -51,7 +52,7 @@ else: # Path to CA bundle
UNRAID_VERIFY_SSL = raw_verify_ssl
# Logging Configuration
LOG_LEVEL_STR = os.getenv('UNRAID_MCP_LOG_LEVEL', 'INFO').upper()
LOG_LEVEL_STR = os.getenv("UNRAID_MCP_LOG_LEVEL", "INFO").upper()
LOG_FILE_NAME = os.getenv("UNRAID_MCP_LOG_FILE", "unraid-mcp.log")
LOGS_DIR = Path("/tmp")
LOG_FILE_PATH = LOGS_DIR / LOG_FILE_NAME
@@ -61,8 +62,8 @@ LOGS_DIR.mkdir(parents=True, exist_ok=True)
# HTTP Client Configuration
TIMEOUT_CONFIG = {
'default': 30,
'disk_operations': 90, # Longer timeout for SMART data queries
"default": 30,
"disk_operations": 90, # Longer timeout for SMART data queries
}
@@ -94,15 +95,15 @@ def get_config_summary() -> dict[str, Any]:
is_valid, missing = validate_required_config()
return {
'api_url_configured': bool(UNRAID_API_URL),
'api_url_preview': UNRAID_API_URL[:20] + '...' if UNRAID_API_URL else None,
'api_key_configured': bool(UNRAID_API_KEY),
'server_host': UNRAID_MCP_HOST,
'server_port': UNRAID_MCP_PORT,
'transport': UNRAID_MCP_TRANSPORT,
'ssl_verify': UNRAID_VERIFY_SSL,
'log_level': LOG_LEVEL_STR,
'log_file': str(LOG_FILE_PATH),
'config_valid': is_valid,
'missing_config': missing if not is_valid else None
"api_url_configured": bool(UNRAID_API_URL),
"api_url_preview": UNRAID_API_URL[:20] + "..." if UNRAID_API_URL else None,
"api_key_configured": bool(UNRAID_API_KEY),
"server_host": UNRAID_MCP_HOST,
"server_port": UNRAID_MCP_PORT,
"transport": UNRAID_MCP_TRANSPORT,
"ssl_verify": UNRAID_VERIFY_SSL,
"log_level": LOG_LEVEL_STR,
"log_file": str(LOG_FILE_PATH),
"config_valid": is_valid,
"missing_config": missing if not is_valid else None
}

View File

@@ -20,14 +20,21 @@ from ..config.settings import (
)
from ..core.exceptions import ToolError
# Sensitive keys to redact from debug logs
_SENSITIVE_KEYS = {"password", "key", "secret", "token", "apikey"}
def _is_sensitive_key(key: str) -> bool:
"""Check if a key name contains any sensitive substring."""
key_lower = key.lower()
return any(s in key_lower for s in _SENSITIVE_KEYS)
def _redact_sensitive(obj: Any) -> Any:
"""Recursively redact sensitive values from nested dicts/lists."""
if isinstance(obj, dict):
return {k: ("***" if k.lower() in _SENSITIVE_KEYS else _redact_sensitive(v)) for k, v in obj.items()}
return {k: ("***" if _is_sensitive_key(k) else _redact_sensitive(v)) for k, v in obj.items()}
if isinstance(obj, list):
return [_redact_sensitive(item) for item in obj]
return obj
@@ -35,7 +42,25 @@ def _redact_sensitive(obj: Any) -> Any:
# HTTP timeout configuration
DEFAULT_TIMEOUT = httpx.Timeout(10.0, read=30.0, connect=5.0)
DISK_TIMEOUT = httpx.Timeout(10.0, read=TIMEOUT_CONFIG['disk_operations'], connect=5.0)
DISK_TIMEOUT = httpx.Timeout(10.0, read=TIMEOUT_CONFIG["disk_operations"], connect=5.0)
# Named timeout profiles
_TIMEOUT_PROFILES: dict[str, httpx.Timeout] = {
"default": DEFAULT_TIMEOUT,
"disk_operations": DISK_TIMEOUT,
}
def get_timeout_for_operation(profile: str) -> httpx.Timeout:
"""Get a timeout configuration by profile name.
Args:
profile: Timeout profile name (e.g., "default", "disk_operations")
Returns:
The matching httpx.Timeout, falling back to DEFAULT_TIMEOUT for unknown profiles
"""
return _TIMEOUT_PROFILES.get(profile, DEFAULT_TIMEOUT)
# Global connection pool (module-level singleton)
_http_client: httpx.AsyncClient | None = None
@@ -55,26 +80,54 @@ def is_idempotent_error(error_message: str, operation: str) -> bool:
error_lower = error_message.lower()
# Docker container operation patterns
if operation == 'start':
if operation == "start":
return (
'already started' in error_lower or
'container already running' in error_lower or
'http code 304' in error_lower
"already started" in error_lower or
"container already running" in error_lower or
"http code 304" in error_lower
)
elif operation == 'stop':
if operation == "stop":
return (
'already stopped' in error_lower or
'container already stopped' in error_lower or
'container not running' in error_lower or
'http code 304' in error_lower
"already stopped" in error_lower or
"container already stopped" in error_lower or
"container not running" in error_lower or
"http code 304" in error_lower
)
return False
async def _create_http_client() -> httpx.AsyncClient:
"""Create a new HTTP client instance with connection pooling.
Returns:
A new AsyncClient configured for Unraid API communication
"""
return httpx.AsyncClient(
# Connection pool settings
limits=httpx.Limits(
max_keepalive_connections=20,
max_connections=100,
keepalive_expiry=30.0
),
# Default timeout (can be overridden per-request)
timeout=DEFAULT_TIMEOUT,
# SSL verification
verify=UNRAID_VERIFY_SSL,
# Connection pooling headers
headers={
"Connection": "keep-alive",
"User-Agent": f"UnraidMCPServer/{VERSION}"
}
)
async def get_http_client() -> httpx.AsyncClient:
"""Get or create shared HTTP client with connection pooling.
The client is protected by an asyncio lock to prevent concurrent creation.
If the existing client was closed (e.g., during shutdown), a new one is created.
Returns:
Singleton AsyncClient instance with connection pooling enabled
"""
@@ -82,26 +135,21 @@ async def get_http_client() -> httpx.AsyncClient:
async with _client_lock:
if _http_client is None or _http_client.is_closed:
_http_client = httpx.AsyncClient(
# Connection pool settings
limits=httpx.Limits(
max_keepalive_connections=20,
max_connections=100,
keepalive_expiry=30.0
),
# Default timeout (can be overridden per-request)
timeout=DEFAULT_TIMEOUT,
# SSL verification
verify=UNRAID_VERIFY_SSL,
# Connection pooling headers
headers={
"Connection": "keep-alive",
"User-Agent": f"UnraidMCPServer/{VERSION}"
}
)
_http_client = await _create_http_client()
logger.info("Created shared HTTP client with connection pooling (20 keepalive, 100 max connections)")
return _http_client
client = _http_client
# Verify client is still open after releasing the lock.
# In asyncio's cooperative model this is unlikely to fail, but guards
# against edge cases where close_http_client runs between yield points.
if client.is_closed:
async with _client_lock:
_http_client = await _create_http_client()
client = _http_client
logger.info("Re-created HTTP client after unexpected close")
return client
async def close_http_client() -> None:
@@ -175,12 +223,12 @@ async def make_graphql_request(
response.raise_for_status() # Raise an exception for HTTP error codes 4xx/5xx
response_data = response.json()
if "errors" in response_data and response_data["errors"]:
if response_data.get("errors"):
error_details = "; ".join([err.get("message", str(err)) for err in response_data["errors"]])
# Check if this is an idempotent error that should be treated as success
if operation_context and operation_context.get('operation'):
operation = operation_context['operation']
if operation_context and operation_context.get("operation"):
operation = operation_context["operation"]
if is_idempotent_error(error_details, operation):
logger.warning(f"Idempotent operation '{operation}' - treating as success: {error_details}")
# Return a success response with the current state information
@@ -204,22 +252,7 @@ async def make_graphql_request(
raise ToolError(f"HTTP error {e.response.status_code}: {e.response.text}") from e
except httpx.RequestError as e:
logger.error(f"Request error occurred: {e}")
raise ToolError(f"Network connection error: {str(e)}") from e
raise ToolError(f"Network connection error: {e!s}") from e
except json.JSONDecodeError as e:
logger.error(f"Failed to decode JSON response: {e}")
raise ToolError(f"Invalid JSON response from Unraid API: {str(e)}") from e
def get_timeout_for_operation(operation_type: str = "default") -> httpx.Timeout:
"""Get appropriate timeout configuration for different operation types.
Args:
operation_type: Type of operation ('default', 'disk_operations')
Returns:
httpx.Timeout configuration appropriate for the operation
"""
if operation_type == "disk_operations":
return DISK_TIMEOUT
else:
return DEFAULT_TIMEOUT
raise ToolError(f"Invalid JSON response from Unraid API: {e!s}") from e

View File

@@ -6,6 +6,7 @@ the modular server implementation from unraid_mcp.server.
"""
import asyncio
import sys
async def shutdown_cleanup() -> None:
@@ -26,16 +27,20 @@ def main() -> None:
print("\nServer stopped by user")
try:
asyncio.run(shutdown_cleanup())
except RuntimeError:
# Event loop already closed, skip cleanup
pass
except RuntimeError as e:
if "event loop is closed" in str(e).lower() or "no running event loop" in str(e).lower():
pass # Expected during shutdown
else:
print(f"WARNING: Unexpected error during cleanup: {e}", file=sys.stderr)
except Exception as e:
print(f"Server failed to start: {e}")
try:
asyncio.run(shutdown_cleanup())
except RuntimeError:
# Event loop already closed, skip cleanup
pass
except RuntimeError as e:
if "event loop is closed" in str(e).lower() or "no running event loop" in str(e).lower():
pass # Expected during shutdown
else:
print(f"WARNING: Unexpected error during cleanup: {e}", file=sys.stderr)
raise

View File

@@ -29,6 +29,7 @@ from .tools.storage import register_storage_tool
from .tools.users import register_users_tool
from .tools.virtualization import register_vm_tool
# Initialize FastMCP instance
mcp = FastMCP(
name="Unraid MCP Server",
@@ -48,16 +49,20 @@ def register_all_modules() -> None:
logger.info("Subscription resources registered")
# Register all 10 consolidated tools
register_info_tool(mcp)
register_array_tool(mcp)
register_storage_tool(mcp)
register_docker_tool(mcp)
register_vm_tool(mcp)
register_notifications_tool(mcp)
register_rclone_tool(mcp)
register_users_tool(mcp)
register_keys_tool(mcp)
register_health_tool(mcp)
registrars = [
register_info_tool,
register_array_tool,
register_storage_tool,
register_docker_tool,
register_vm_tool,
register_notifications_tool,
register_rclone_tool,
register_users_tool,
register_keys_tool,
register_health_tool,
]
for registrar in registrars:
registrar(mcp)
logger.info("All 10 tools registered successfully - Server ready!")

View File

@@ -7,7 +7,6 @@ development and debugging purposes.
import asyncio
import json
import ssl
from datetime import datetime
from typing import Any
@@ -16,10 +15,11 @@ from fastmcp import FastMCP
from websockets.typing import Subprotocol
from ..config.logging import logger
from ..config.settings import UNRAID_API_KEY, UNRAID_API_URL, UNRAID_VERIFY_SSL
from ..config.settings import UNRAID_API_KEY, UNRAID_API_URL
from ..core.exceptions import ToolError
from .manager import subscription_manager
from .resources import ensure_subscriptions_started
from .utils import build_ws_ssl_context
def register_diagnostic_tools(mcp: FastMCP) -> None:
@@ -31,8 +31,8 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
@mcp.tool()
async def test_subscription_query(subscription_query: str) -> dict[str, Any]:
"""
Test a GraphQL subscription query directly to debug schema issues.
"""Test a GraphQL subscription query directly to debug schema issues.
Use this to find working subscription field names and structure.
Args:
@@ -49,15 +49,7 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
raise ToolError("UNRAID_API_URL is not configured")
ws_url = UNRAID_API_URL.replace("https://", "wss://").replace("http://", "ws://") + "/graphql"
# Build SSL context for wss:// connections
ssl_context = None
if ws_url.startswith("wss://"):
if isinstance(UNRAID_VERIFY_SSL, str):
ssl_context = ssl.create_default_context(cafile=UNRAID_VERIFY_SSL)
elif UNRAID_VERIFY_SSL:
ssl_context = ssl.create_default_context()
else:
ssl_context = ssl._create_unverified_context()
ssl_context = build_ws_ssl_context(ws_url)
# Test connection
async with websockets.connect(
@@ -104,7 +96,7 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
"query_tested": subscription_query
}
except asyncio.TimeoutError:
except TimeoutError:
return {
"success": True,
"response": "No immediate response (subscriptions may only send data on changes)",
@@ -121,8 +113,8 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
@mcp.tool()
async def diagnose_subscriptions() -> dict[str, Any]:
"""
Comprehensive diagnostic tool for subscription system.
"""Comprehensive diagnostic tool for subscription system.
Shows detailed status, connection states, errors, and troubleshooting info.
Returns:
@@ -163,14 +155,14 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
# Calculate WebSocket URL
if UNRAID_API_URL:
if UNRAID_API_URL.startswith('https://'):
ws_url = 'wss://' + UNRAID_API_URL[len('https://'):]
elif UNRAID_API_URL.startswith('http://'):
ws_url = 'ws://' + UNRAID_API_URL[len('http://'):]
if UNRAID_API_URL.startswith("https://"):
ws_url = "wss://" + UNRAID_API_URL[len("https://"):]
elif UNRAID_API_URL.startswith("http://"):
ws_url = "ws://" + UNRAID_API_URL[len("http://"):]
else:
ws_url = UNRAID_API_URL
if not ws_url.endswith('/graphql'):
ws_url = ws_url.rstrip('/') + '/graphql'
if not ws_url.endswith("/graphql"):
ws_url = ws_url.rstrip("/") + "/graphql"
diagnostic_info["environment"]["websocket_url"] = ws_url
# Analyze issues
@@ -222,6 +214,6 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
except Exception as e:
logger.error(f"[DIAGNOSTIC] Failed to generate diagnostics: {e}")
raise ToolError(f"Failed to generate diagnostics: {str(e)}") from e
raise ToolError(f"Failed to generate diagnostics: {e!s}") from e
logger.info("Subscription diagnostic tools registered successfully")

View File

@@ -8,7 +8,6 @@ error handling, reconnection logic, and authentication.
import asyncio
import json
import os
import ssl
from datetime import datetime
from typing import Any
@@ -16,8 +15,9 @@ import websockets
from websockets.typing import Subprotocol
from ..config.logging import logger
from ..config.settings import UNRAID_API_KEY, UNRAID_API_URL, UNRAID_VERIFY_SSL
from ..config.settings import UNRAID_API_KEY, UNRAID_API_URL
from ..core.types import SubscriptionData
from .utils import build_ws_ssl_context
class SubscriptionManager:
@@ -141,28 +141,20 @@ class SubscriptionManager:
if not UNRAID_API_URL:
raise ValueError("UNRAID_API_URL is not configured")
if UNRAID_API_URL.startswith('https://'):
ws_url = 'wss://' + UNRAID_API_URL[len('https://'):]
elif UNRAID_API_URL.startswith('http://'):
ws_url = 'ws://' + UNRAID_API_URL[len('http://'):]
if UNRAID_API_URL.startswith("https://"):
ws_url = "wss://" + UNRAID_API_URL[len("https://"):]
elif UNRAID_API_URL.startswith("http://"):
ws_url = "ws://" + UNRAID_API_URL[len("http://"):]
else:
ws_url = UNRAID_API_URL
if not ws_url.endswith('/graphql'):
ws_url = ws_url.rstrip('/') + '/graphql'
if not ws_url.endswith("/graphql"):
ws_url = ws_url.rstrip("/") + "/graphql"
logger.debug(f"[WEBSOCKET:{subscription_name}] Connecting to: {ws_url}")
logger.debug(f"[WEBSOCKET:{subscription_name}] API Key present: {'Yes' if UNRAID_API_KEY else 'No'}")
# Build SSL context for wss:// connections
ssl_context = None
if ws_url.startswith('wss://'):
if isinstance(UNRAID_VERIFY_SSL, str):
ssl_context = ssl.create_default_context(cafile=UNRAID_VERIFY_SSL)
elif UNRAID_VERIFY_SSL:
ssl_context = ssl.create_default_context()
else:
ssl_context = ssl._create_unverified_context()
ssl_context = build_ws_ssl_context(ws_url)
# Connection with timeout
connect_timeout = 10
@@ -213,7 +205,7 @@ class SubscriptionManager:
init_data = json.loads(init_raw)
logger.debug(f"[PROTOCOL:{subscription_name}] Received init response: {init_data.get('type')}")
except json.JSONDecodeError as e:
init_preview = init_raw[:200] if isinstance(init_raw, str) else init_raw[:200].decode('utf-8', errors='replace')
init_preview = init_raw[:200] if isinstance(init_raw, str) else init_raw[:200].decode("utf-8", errors="replace")
logger.error(f"[PROTOCOL:{subscription_name}] Failed to decode init response: {init_preview}...")
self.last_error[subscription_name] = f"Invalid JSON in init response: {e}"
break
@@ -223,7 +215,7 @@ class SubscriptionManager:
logger.info(f"[PROTOCOL:{subscription_name}] Connection acknowledged successfully")
self.connection_states[subscription_name] = "authenticated"
elif init_data.get("type") == "connection_error":
error_payload = init_data.get('payload', {})
error_payload = init_data.get("payload", {})
logger.error(f"[AUTH:{subscription_name}] Authentication failed: {error_payload}")
self.last_error[subscription_name] = f"Authentication error: {error_payload}"
self.connection_states[subscription_name] = "auth_failed"
@@ -259,7 +251,7 @@ class SubscriptionManager:
try:
data = json.loads(message)
message_count += 1
message_type = data.get('type', 'unknown')
message_type = data.get("type", "unknown")
logger.debug(f"[DATA:{subscription_name}] Message #{message_count}: {message_type}")
@@ -288,7 +280,7 @@ class SubscriptionManager:
await websocket.send(json.dumps({"type": "pong"}))
elif data.get("type") == "error":
error_payload = data.get('payload', {})
error_payload = data.get("payload", {})
logger.error(f"[SUBSCRIPTION:{subscription_name}] Subscription error: {error_payload}")
self.last_error[subscription_name] = f"Subscription error: {error_payload}"
self.connection_states[subscription_name] = "error"
@@ -305,15 +297,15 @@ class SubscriptionManager:
logger.debug(f"[PROTOCOL:{subscription_name}] Unhandled message type: {message_type}")
except json.JSONDecodeError as e:
msg_preview = message[:200] if isinstance(message, str) else message[:200].decode('utf-8', errors='replace')
msg_preview = message[:200] if isinstance(message, str) else message[:200].decode("utf-8", errors="replace")
logger.error(f"[PROTOCOL:{subscription_name}] Failed to decode message: {msg_preview}...")
logger.error(f"[PROTOCOL:{subscription_name}] JSON decode error: {e}")
except Exception as e:
logger.error(f"[DATA:{subscription_name}] Error processing message: {e}")
msg_preview = message[:200] if isinstance(message, str) else message[:200].decode('utf-8', errors='replace')
msg_preview = message[:200] if isinstance(message, str) else message[:200].decode("utf-8", errors="replace")
logger.debug(f"[DATA:{subscription_name}] Raw message: {msg_preview}...")
except asyncio.TimeoutError:
except TimeoutError:
error_msg = "Connection or authentication timeout"
logger.error(f"[WEBSOCKET:{subscription_name}] {error_msg}")
self.last_error[subscription_name] = error_msg
@@ -353,9 +345,8 @@ class SubscriptionManager:
age_seconds = (datetime.now() - data.last_updated).total_seconds()
logger.debug(f"[RESOURCE:{resource_name}] Data found, age: {age_seconds:.1f}s")
return data.data
else:
logger.debug(f"[RESOURCE:{resource_name}] No data available")
return None
logger.debug(f"[RESOURCE:{resource_name}] No data available")
return None
def list_active_subscriptions(self) -> list[str]:
"""List all active subscriptions."""

View File

@@ -13,6 +13,7 @@ from fastmcp import FastMCP
from ..config.logging import logger
from .manager import subscription_manager
# Global flag to track subscription startup
_subscriptions_started = False

View File

@@ -0,0 +1,27 @@
"""Shared utilities for the subscription system."""
import ssl as _ssl
from ..config.settings import UNRAID_VERIFY_SSL
def build_ws_ssl_context(ws_url: str) -> _ssl.SSLContext | None:
"""Build an SSL context for WebSocket connections when using wss://.
Args:
ws_url: The WebSocket URL to connect to.
Returns:
An SSLContext configured per UNRAID_VERIFY_SSL, or None for non-TLS URLs.
"""
if not ws_url.startswith("wss://"):
return None
if isinstance(UNRAID_VERIFY_SSL, str):
return _ssl.create_default_context(cafile=UNRAID_VERIFY_SSL)
if UNRAID_VERIFY_SSL:
return _ssl.create_default_context()
# Explicitly disable verification (equivalent to verify=False)
ctx = _ssl.SSLContext(_ssl.PROTOCOL_TLS_CLIENT)
ctx.check_hostname = False
ctx.verify_mode = _ssl.CERT_NONE
return ctx

View File

@@ -12,9 +12,10 @@ from ..config.logging import logger
from ..core.client import make_graphql_request
from ..core.exceptions import ToolError
QUERIES: dict[str, str] = {
"parity_history": """
query GetParityHistory {
"parity_status": """
query GetParityStatus {
array { parityCheckStatus { progress speed errors } }
}
""",
@@ -80,10 +81,11 @@ MUTATIONS: dict[str, str] = {
DESTRUCTIVE_ACTIONS = {"start", "stop", "shutdown", "reboot"}
DISK_ACTIONS = {"mount_disk", "unmount_disk", "clear_stats"}
ALL_ACTIONS = set(QUERIES) | set(MUTATIONS)
ARRAY_ACTIONS = Literal[
"start", "stop",
"parity_start", "parity_pause", "parity_resume", "parity_cancel", "parity_history",
"parity_start", "parity_pause", "parity_resume", "parity_cancel", "parity_status",
"mount_disk", "unmount_disk", "clear_stats",
"shutdown", "reboot",
]
@@ -108,16 +110,15 @@ def register_array_tool(mcp: FastMCP) -> None:
parity_pause - Pause running parity check
parity_resume - Resume paused parity check
parity_cancel - Cancel running parity check
parity_history - Get parity check status/history
parity_status - Get current parity check status
mount_disk - Mount an array disk (requires disk_id)
unmount_disk - Unmount an array disk (requires disk_id)
clear_stats - Clear disk statistics (requires disk_id)
shutdown - Shut down the server (destructive, requires confirm=True)
reboot - Reboot the server (destructive, requires confirm=True)
"""
all_actions = set(QUERIES) | set(MUTATIONS)
if action not in all_actions:
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(all_actions)}")
if action not in ALL_ACTIONS:
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(ALL_ACTIONS)}")
if action in DESTRUCTIVE_ACTIONS and not confirm:
raise ToolError(
@@ -156,6 +157,6 @@ def register_array_tool(mcp: FastMCP) -> None:
raise
except Exception as e:
logger.error(f"Error in unraid_array action={action}: {e}", exc_info=True)
raise ToolError(f"Failed to execute array/{action}: {str(e)}") from e
raise ToolError(f"Failed to execute array/{action}: {e!s}") from e
logger.info("Array tool registered successfully")

View File

@@ -13,6 +13,7 @@ from ..config.logging import logger
from ..core.client import make_graphql_request
from ..core.exceptions import ToolError
QUERIES: dict[str, str] = {
"list": """
query ListDockerContainers {
@@ -98,7 +99,8 @@ MUTATIONS: dict[str, str] = {
}
DESTRUCTIVE_ACTIONS = {"remove"}
CONTAINER_ACTIONS = {"start", "stop", "restart", "pause", "unpause", "remove", "update", "details", "logs"}
_ACTIONS_REQUIRING_CONTAINER_ID = {"start", "stop", "restart", "pause", "unpause", "remove", "update", "details", "logs"}
ALL_ACTIONS = set(QUERIES) | set(MUTATIONS) | {"restart"}
DOCKER_ACTIONS = Literal[
"list", "details", "start", "stop", "restart", "pause", "unpause",
@@ -107,7 +109,7 @@ DOCKER_ACTIONS = Literal[
]
# Docker container IDs: 64 hex chars + optional suffix (e.g., ":local")
_DOCKER_ID_PATTERN = re.compile(r"^[a-f0-9]{64}(:[a-z0-9]+)?$")
_DOCKER_ID_PATTERN = re.compile(r"^[a-f0-9]{64}(:[a-z0-9]+)?$", re.IGNORECASE)
def find_container_by_identifier(
@@ -175,6 +177,7 @@ def register_docker_tool(mcp: FastMCP) -> None:
action: DOCKER_ACTIONS,
container_id: str | None = None,
network_id: str | None = None,
*,
confirm: bool = False,
tail_lines: int = 100,
) -> dict[str, Any]:
@@ -197,14 +200,13 @@ def register_docker_tool(mcp: FastMCP) -> None:
port_conflicts - Check for port conflicts
check_updates - Check which containers have updates available
"""
all_actions = set(QUERIES) | set(MUTATIONS) | {"restart"}
if action not in all_actions:
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(all_actions)}")
if action not in ALL_ACTIONS:
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(ALL_ACTIONS)}")
if action in DESTRUCTIVE_ACTIONS and not confirm:
raise ToolError(f"Action '{action}' is destructive. Set confirm=True to proceed.")
if action in CONTAINER_ACTIONS and not container_id:
if action in _ACTIONS_REQUIRING_CONTAINER_ID and not container_id:
raise ToolError(f"container_id is required for '{action}' action")
if action == "network_details" and not network_id:
@@ -327,6 +329,6 @@ def register_docker_tool(mcp: FastMCP) -> None:
raise
except Exception as e:
logger.error(f"Error in unraid_docker action={action}: {e}", exc_info=True)
raise ToolError(f"Failed to execute docker/{action}: {str(e)}") from e
raise ToolError(f"Failed to execute docker/{action}: {e!s}") from e
logger.info("Docker tool registered successfully")

View File

@@ -21,12 +21,24 @@ from ..config.settings import (
from ..core.client import make_graphql_request
from ..core.exceptions import ToolError
HEALTH_ACTIONS = Literal["check", "test_connection", "diagnose"]
# Severity ordering: only upgrade, never downgrade
_SEVERITY = {"healthy": 0, "warning": 1, "degraded": 2, "unhealthy": 3}
def _server_info() -> dict[str, Any]:
"""Return the standard server info block used in health responses."""
return {
"name": "Unraid MCP Server",
"version": VERSION,
"transport": UNRAID_MCP_TRANSPORT,
"host": UNRAID_MCP_HOST,
"port": UNRAID_MCP_PORT,
}
def register_health_tool(mcp: FastMCP) -> None:
"""Register the unraid_health tool with the FastMCP instance."""
@@ -71,7 +83,7 @@ def register_health_tool(mcp: FastMCP) -> None:
raise
except Exception as e:
logger.error(f"Error in unraid_health action={action}: {e}", exc_info=True)
raise ToolError(f"Failed to execute health/{action}: {str(e)}") from e
raise ToolError(f"Failed to execute health/{action}: {e!s}") from e
logger.info("Health tool registered successfully")
@@ -108,15 +120,9 @@ async def _comprehensive_check() -> dict[str, Any]:
health_info: dict[str, Any] = {
"status": "healthy",
"timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(),
"timestamp": datetime.datetime.now(datetime.UTC).isoformat(),
"api_latency_ms": api_latency,
"server": {
"name": "Unraid MCP Server",
"version": VERSION,
"transport": UNRAID_MCP_TRANSPORT,
"host": UNRAID_MCP_HOST,
"port": UNRAID_MCP_PORT,
},
"server": _server_info(),
}
if not data:
@@ -201,15 +207,9 @@ async def _comprehensive_check() -> dict[str, Any]:
logger.error(f"Health check failed: {e}")
return {
"status": "unhealthy",
"timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(),
"timestamp": datetime.datetime.now(datetime.UTC).isoformat(),
"error": str(e),
"server": {
"name": "Unraid MCP Server",
"version": VERSION,
"transport": UNRAID_MCP_TRANSPORT,
"host": UNRAID_MCP_HOST,
"port": UNRAID_MCP_PORT,
},
"server": _server_info(),
}
@@ -225,7 +225,7 @@ async def _diagnose_subscriptions() -> dict[str, Any]:
connection_issues: list[dict[str, Any]] = []
diagnostic_info: dict[str, Any] = {
"timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(),
"timestamp": datetime.datetime.now(datetime.UTC).isoformat(),
"environment": {
"auto_start_enabled": subscription_manager.auto_start_enabled,
"max_reconnect_attempts": subscription_manager.max_reconnect_attempts,
@@ -258,7 +258,7 @@ async def _diagnose_subscriptions() -> dict[str, Any]:
except ImportError:
return {
"error": "Subscription modules not available",
"timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(),
"timestamp": datetime.datetime.now(datetime.UTC).isoformat(),
}
except Exception as e:
raise ToolError(f"Failed to generate diagnostics: {str(e)}") from e
raise ToolError(f"Failed to generate diagnostics: {e!s}") from e

View File

@@ -12,6 +12,7 @@ from ..config.logging import logger
from ..core.client import make_graphql_request
from ..core.exceptions import ToolError
# Pre-built queries keyed by action name
QUERIES: dict[str, str] = {
"overview": """
@@ -162,6 +163,10 @@ INFO_ACTIONS = Literal[
"ups_devices", "ups_device", "ups_config",
]
assert set(QUERIES.keys()) == set(INFO_ACTIONS.__args__), (
"QUERIES keys and INFO_ACTIONS are out of sync"
)
def _process_system_info(raw_info: dict[str, Any]) -> dict[str, Any]:
"""Process raw system info into summary + details."""
@@ -179,7 +184,7 @@ def _process_system_info(raw_info: dict[str, Any]) -> dict[str, Any]:
cpu = raw_info["cpu"]
summary["cpu"] = (
f"{cpu.get('manufacturer', '')} {cpu.get('brand', '')} "
f"({cpu.get('cores')} cores, {cpu.get('threads')} threads)"
f"({cpu.get('cores', '?')} cores, {cpu.get('threads', '?')} threads)"
)
if raw_info.get("memory") and raw_info["memory"].get("layout"):
@@ -227,27 +232,31 @@ def _analyze_disk_health(disks: list[dict[str, Any]]) -> dict[str, int]:
return counts
def _format_kb(k: Any) -> str:
"""Format kilobyte values into human-readable sizes."""
if k is None:
return "N/A"
try:
k = int(k)
except (ValueError, TypeError):
return "N/A"
if k >= 1024 * 1024 * 1024:
return f"{k / (1024 * 1024 * 1024):.2f} TB"
if k >= 1024 * 1024:
return f"{k / (1024 * 1024):.2f} GB"
if k >= 1024:
return f"{k / 1024:.2f} MB"
return f"{k} KB"
def _process_array_status(raw: dict[str, Any]) -> dict[str, Any]:
"""Process raw array data into summary + details."""
def format_kb(k: Any) -> str:
if k is None:
return "N/A"
k = int(k)
if k >= 1024 * 1024 * 1024:
return f"{k / (1024 * 1024 * 1024):.2f} TB"
if k >= 1024 * 1024:
return f"{k / (1024 * 1024):.2f} GB"
if k >= 1024:
return f"{k / 1024:.2f} MB"
return f"{k} KB"
summary: dict[str, Any] = {"state": raw.get("state")}
if raw.get("capacity") and raw["capacity"].get("kilobytes"):
kb = raw["capacity"]["kilobytes"]
summary["capacity_total"] = format_kb(kb.get("total"))
summary["capacity_used"] = format_kb(kb.get("used"))
summary["capacity_free"] = format_kb(kb.get("free"))
summary["capacity_total"] = _format_kb(kb.get("total"))
summary["capacity_used"] = _format_kb(kb.get("used"))
summary["capacity_free"] = _format_kb(kb.get("free"))
summary["num_parity_disks"] = len(raw.get("parities", []))
summary["num_data_disks"] = len(raw.get("disks", []))
@@ -320,81 +329,73 @@ def register_info_tool(mcp: FastMCP) -> None:
if action == "ups_device":
variables = {"id": device_id}
# Lookup tables for common response patterns
# Simple dict actions: action -> GraphQL response key
dict_actions: dict[str, str] = {
"network": "network",
"registration": "registration",
"connect": "connect",
"variables": "vars",
"metrics": "metrics",
"config": "config",
"owner": "owner",
"flash": "flash",
"ups_device": "upsDeviceById",
"ups_config": "upsConfiguration",
}
# List-wrapped actions: action -> (GraphQL response key, output key)
list_actions: dict[str, tuple[str, str]] = {
"services": ("services", "services"),
"servers": ("servers", "servers"),
"ups_devices": ("upsDevices", "ups_devices"),
}
try:
logger.info(f"Executing unraid_info action={action}")
data = await make_graphql_request(query, variables)
# Action-specific response processing
# Special-case actions with custom processing
if action == "overview":
raw = data.get("info", {})
raw = data.get("info") or {}
if not raw:
raise ToolError("No system info returned from Unraid API")
return _process_system_info(raw)
if action == "array":
raw = data.get("array", {})
raw = data.get("array") or {}
if not raw:
raise ToolError("No array information returned from Unraid API")
return _process_array_status(raw)
if action == "network":
return dict(data.get("network", {}))
if action == "registration":
return dict(data.get("registration", {}))
if action == "connect":
return dict(data.get("connect", {}))
if action == "variables":
return dict(data.get("vars", {}))
if action == "metrics":
return dict(data.get("metrics", {}))
if action == "services":
services = data.get("services", [])
return {"services": list(services) if isinstance(services, list) else []}
if action == "display":
info = data.get("info", {})
return dict(info.get("display", {}))
if action == "config":
return dict(data.get("config", {}))
info = data.get("info") or {}
return dict(info.get("display") or {})
if action == "online":
return {"online": data.get("online")}
if action == "owner":
return dict(data.get("owner", {}))
if action == "settings":
settings = data.get("settings", {})
if settings and settings.get("unified"):
values = settings["unified"].get("values", {})
return dict(values) if isinstance(values, dict) else {"raw": values}
return {}
settings = data.get("settings") or {}
if not settings:
raise ToolError("No settings data returned from Unraid API. Check API permissions.")
if not settings.get("unified"):
logger.warning(f"Settings returned unexpected structure: {settings.keys()}")
raise ToolError(f"Unexpected settings structure. Expected 'unified' key, got: {list(settings.keys())}")
values = settings["unified"].get("values") or {}
return dict(values) if isinstance(values, dict) else {"raw": values}
if action == "server":
return data
if action == "servers":
servers = data.get("servers", [])
return {"servers": list(servers) if isinstance(servers, list) else []}
# Simple dict-returning actions
if action in dict_actions:
return dict(data.get(dict_actions[action]) or {})
if action == "flash":
return dict(data.get("flash", {}))
if action == "ups_devices":
devices = data.get("upsDevices", [])
return {"ups_devices": list(devices) if isinstance(devices, list) else []}
if action == "ups_device":
return dict(data.get("upsDeviceById", {}))
if action == "ups_config":
return dict(data.get("upsConfiguration", {}))
# List-wrapped actions
if action in list_actions:
response_key, output_key = list_actions[action]
items = data.get(response_key) or []
return {output_key: list(items) if isinstance(items, list) else []}
raise ToolError(f"Unhandled action '{action}' — this is a bug")
@@ -402,6 +403,6 @@ def register_info_tool(mcp: FastMCP) -> None:
raise
except Exception as e:
logger.error(f"Error in unraid_info action={action}: {e}", exc_info=True)
raise ToolError(f"Failed to execute info/{action}: {str(e)}") from e
raise ToolError(f"Failed to execute info/{action}: {e!s}") from e
logger.info("Info tool registered successfully")

View File

@@ -12,6 +12,7 @@ from ..config.logging import logger
from ..core.client import make_graphql_request
from ..core.exceptions import ToolError
QUERIES: dict[str, str] = {
"list": """
query ListApiKeys {
@@ -90,7 +91,7 @@ def register_keys_tool(mcp: FastMCP) -> None:
if not key_id:
raise ToolError("key_id is required for 'get' action")
data = await make_graphql_request(QUERIES["get"], {"id": key_id})
return dict(data.get("apiKey", {}))
return dict(data.get("apiKey") or {})
if action == "create":
if not name:
@@ -144,6 +145,6 @@ def register_keys_tool(mcp: FastMCP) -> None:
raise
except Exception as e:
logger.error(f"Error in unraid_keys action={action}: {e}", exc_info=True)
raise ToolError(f"Failed to execute keys/{action}: {str(e)}") from e
raise ToolError(f"Failed to execute keys/{action}: {e!s}") from e
logger.info("Keys tool registered successfully")

View File

@@ -12,6 +12,7 @@ from ..config.logging import logger
from ..core.client import make_graphql_request
from ..core.exceptions import ToolError
QUERIES: dict[str, str] = {
"overview": """
query GetNotificationsOverview {
@@ -124,8 +125,8 @@ def register_notifications_tool(mcp: FastMCP) -> None:
if action == "overview":
data = await make_graphql_request(QUERIES["overview"])
notifications = data.get("notifications", {})
return dict(notifications.get("overview", {}))
notifications = data.get("notifications") or {}
return dict(notifications.get("overview") or {})
if action == "list":
filter_vars: dict[str, Any] = {
@@ -200,6 +201,6 @@ def register_notifications_tool(mcp: FastMCP) -> None:
raise
except Exception as e:
logger.error(f"Error in unraid_notifications action={action}: {e}", exc_info=True)
raise ToolError(f"Failed to execute notifications/{action}: {str(e)}") from e
raise ToolError(f"Failed to execute notifications/{action}: {e!s}") from e
logger.info("Notifications tool registered successfully")

View File

@@ -12,6 +12,7 @@ from ..config.logging import logger
from ..core.client import make_graphql_request
from ..core.exceptions import ToolError
QUERIES: dict[str, str] = {
"list_remotes": """
query ListRCloneRemotes {
@@ -39,6 +40,7 @@ MUTATIONS: dict[str, str] = {
}
DESTRUCTIVE_ACTIONS = {"delete_remote"}
ALL_ACTIONS = set(QUERIES) | set(MUTATIONS)
RCLONE_ACTIONS = Literal[
"list_remotes", "config_form", "create_remote", "delete_remote",
@@ -64,9 +66,8 @@ def register_rclone_tool(mcp: FastMCP) -> None:
create_remote - Create a new remote (requires name, provider_type, config_data)
delete_remote - Delete a remote (requires name, confirm=True)
"""
all_actions = set(QUERIES) | set(MUTATIONS)
if action not in all_actions:
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(all_actions)}")
if action not in ALL_ACTIONS:
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(ALL_ACTIONS)}")
if action in DESTRUCTIVE_ACTIONS and not confirm:
raise ToolError(f"Action '{action}' is destructive. Set confirm=True to proceed.")
@@ -129,6 +130,6 @@ def register_rclone_tool(mcp: FastMCP) -> None:
raise
except Exception as e:
logger.error(f"Error in unraid_rclone action={action}: {e}", exc_info=True)
raise ToolError(f"Failed to execute rclone/{action}: {str(e)}") from e
raise ToolError(f"Failed to execute rclone/{action}: {e!s}") from e
logger.info("RClone tool registered successfully")

View File

@@ -4,7 +4,7 @@ Provides the `unraid_storage` tool with 6 actions for shares, physical disks,
unassigned devices, log files, and log content retrieval.
"""
import posixpath
from pathlib import Path
from typing import Any, Literal
from fastmcp import FastMCP
@@ -102,8 +102,8 @@ def register_storage_tool(mcp: FastMCP) -> None:
if action == "logs":
if not log_path:
raise ToolError("log_path is required for 'logs' action")
# Normalize path to prevent traversal attacks (e.g. /var/log/../../etc/shadow)
normalized = posixpath.normpath(log_path)
# Resolve path to prevent traversal attacks (e.g. /var/log/../../etc/shadow)
normalized = str(Path(log_path).resolve())
if not any(normalized.startswith(p) for p in _ALLOWED_LOG_PREFIXES):
raise ToolError(
f"log_path must start with one of: {', '.join(_ALLOWED_LOG_PREFIXES)}. "
@@ -143,8 +143,8 @@ def register_storage_tool(mcp: FastMCP) -> None:
"serial_number": raw.get("serialNum"),
"size_formatted": format_bytes(raw.get("size")),
"temperature": (
f"{raw.get('temperature')}C"
if raw.get("temperature")
f"{raw['temperature']}\u00b0C"
if raw.get("temperature") is not None
else "N/A"
),
}
@@ -159,7 +159,7 @@ def register_storage_tool(mcp: FastMCP) -> None:
return {"log_files": list(files) if isinstance(files, list) else []}
if action == "logs":
return dict(data.get("logFile", {}))
return dict(data.get("logFile") or {})
raise ToolError(f"Unhandled action '{action}' — this is a bug")

View File

@@ -12,6 +12,7 @@ from ..config.logging import logger
from ..core.client import make_graphql_request
from ..core.exceptions import ToolError
QUERIES: dict[str, str] = {
"me": """
query GetMe {
@@ -158,6 +159,6 @@ def register_users_tool(mcp: FastMCP) -> None:
raise
except Exception as e:
logger.error(f"Error in unraid_users action={action}: {e}", exc_info=True)
raise ToolError(f"Failed to execute users/{action}: {str(e)}") from e
raise ToolError(f"Failed to execute users/{action}: {e!s}") from e
logger.info("Users tool registered successfully")

View File

@@ -12,17 +12,13 @@ from ..config.logging import logger
from ..core.client import make_graphql_request
from ..core.exceptions import ToolError
QUERIES: dict[str, str] = {
"list": """
query ListVMs {
vms { id domains { id name state uuid } }
}
""",
"details": """
query GetVmDetails {
vms { domains { id name state uuid } }
}
""",
}
MUTATIONS: dict[str, str] = {
@@ -49,15 +45,9 @@ MUTATIONS: dict[str, str] = {
""",
}
# Map action names to their GraphQL field names
# Map action names to GraphQL field names (only where they differ)
_MUTATION_FIELDS: dict[str, str] = {
"start": "start",
"stop": "stop",
"pause": "pause",
"resume": "resume",
"force_stop": "forceStop",
"reboot": "reboot",
"reset": "reset",
}
DESTRUCTIVE_ACTIONS = {"force_stop", "reset"}
@@ -90,7 +80,7 @@ def register_vm_tool(mcp: FastMCP) -> None:
reboot - Reboot a VM (requires vm_id)
reset - Reset a VM (requires vm_id, confirm=True)
"""
all_actions = set(QUERIES) | set(MUTATIONS)
all_actions = set(QUERIES) | set(MUTATIONS) | {"details"}
if action not in all_actions:
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(all_actions)}")
@@ -103,39 +93,40 @@ def register_vm_tool(mcp: FastMCP) -> None:
try:
logger.info(f"Executing unraid_vm action={action}")
if action == "list":
if action in ("list", "details"):
data = await make_graphql_request(QUERIES["list"])
if data.get("vms"):
vms = data["vms"].get("domains") or data["vms"].get("domain")
if vms:
return {"vms": list(vms) if isinstance(vms, list) else []}
return {"vms": []}
if action == "details":
data = await make_graphql_request(QUERIES["details"])
if data.get("vms"):
vms = data["vms"].get("domains") or data["vms"].get("domain") or []
if isinstance(vms, dict):
vms = [vms]
if action == "list":
return {"vms": vms}
# details: find specific VM
for vm in vms:
if (
vm.get("uuid") == vm_id
or vm.get("id") == vm_id
or vm.get("name") == vm_id
):
return dict(vm) if isinstance(vm, dict) else {}
return dict(vm)
available = [
f"{v.get('name')} (UUID: {v.get('uuid')})" for v in vms
]
raise ToolError(
f"VM '{vm_id}' not found. Available: {', '.join(available)}"
)
raise ToolError("No VM data returned from server")
if action == "details":
raise ToolError("No VM data returned from server")
return {"vms": []}
# Mutations
if action in MUTATIONS:
data = await make_graphql_request(
MUTATIONS[action], {"id": vm_id}
)
field = _MUTATION_FIELDS[action]
field = _MUTATION_FIELDS.get(action, action)
if data.get("vm") and field in data["vm"]:
return {
"success": data["vm"][field],