mirror of
https://github.com/jmagar/unraid-mcp.git
synced 2026-03-01 16:04:24 -08:00
refactor: comprehensive code review fixes across 31 files
Addresses all critical, high, medium, and low issues from full codebase review. 494 tests pass, ruff clean, ty type-check clean. Security: - Add tool_error_handler context manager (exceptions.py) — standardised error handling, eliminates 11 bare except-reraise patterns - Remove unused exception subclasses (ConfigurationError, UnraidAPIError, SubscriptionError, ValidationError, IdempotentOperationError) - Harden GraphQL subscription query validator with allow-list and forbidden-keyword regex (diagnostics.py) - Add input validation for rclone create_remote config_data: injection, path-traversal, and key-count limits (rclone.py) - Validate notifications importance enum before GraphQL request (notifications.py) - Sanitise HTTP/network/JSON error messages — no raw exception strings leaked to clients (client.py) - Strip path/creds from displayed API URL via _safe_display_url (health.py) - Enable Ruff S (bandit) rule category in pyproject.toml - Harden container mutations to strict-only matching — no fuzzy/substring for destructive operations (docker.py) Performance: - Token-bucket rate limiter (90 tokens, 9 req/s) with 429 retry backoff (client.py) - Lazy asyncio.Lock init via _get_client_lock() — fixes event-loop module-load crash (client.py) - Double-checked locking in get_http_client() for fast-path (client.py) - Short hex container ID fast-path skips list fetch (docker.py) - Cap resource_data log content to 1 MB / 5,000 lines (manager.py) - Reset reconnect counter after 30 s stable connection (manager.py) - Move tail_lines validation to module level; enforce 10,000 line cap (storage.py, docker.py) - force_terminal=True removed from logging RichHandler (logging.py) Architecture: - Register diagnostic tools in server startup (server.py) - Move ALL_ACTIONS computation to module level in all tools - Consolidate format_kb / format_bytes into shared core/utils.py - Add _safe_get() helper in core/utils.py for nested dict traversal - Extract _analyze_subscription_status() from health.py diagnose handler - Validate required config at startup — fail fast with CRITICAL log (server.py) Code quality: - Remove ~90 lines of dead Rich formatting helpers from logging.py - Remove dead self.websocket attribute from SubscriptionManager - Remove dead setup_uvicorn_logging() wrapper - Move _VALID_IMPORTANCE to module level (N806 fix) - Add slots=True to all three dataclasses (SubscriptionData, SystemHealth, APIResponse) - Fix None rendering as literal "None" string in info.py summaries - Change fuzzy-match log messages from INFO to DEBUG (docker.py) - UTC-aware datetimes throughout (manager.py, diagnostics.py) Infrastructure: - Upgrade base image python:3.11-slim → python:3.12-slim (Dockerfile) - Add non-root appuser (UID/GID 1000) with HEALTHCHECK (Dockerfile) - Add read_only, cap_drop: ALL, tmpfs /tmp to docker-compose.yml - Single-source version via importlib.metadata (pyproject.toml → __init__.py) - Add open_timeout to all websockets.connect() calls Tests: - Update error message matchers to match sanitised messages (test_client.py) - Fix patch targets for UNRAID_API_URL → utils module (test_subscriptions.py) - Fix importance="info" → importance="normal" (test_notifications.py, http_layer) - Fix naive datetime fixtures → UTC-aware (test_subscriptions.py) Co-authored-by: Claude <claude@anthropic.com>
This commit is contained in:
27
Dockerfile
27
Dockerfile
@@ -1,5 +1,5 @@
|
||||
# Use an official Python runtime as a parent image
|
||||
FROM python:3.11-slim
|
||||
FROM python:3.12-slim
|
||||
|
||||
# Set the working directory in the container
|
||||
WORKDIR /app
|
||||
@@ -7,13 +7,22 @@ WORKDIR /app
|
||||
# Install uv
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /usr/local/bin/
|
||||
|
||||
# Copy dependency files
|
||||
COPY pyproject.toml .
|
||||
COPY uv.lock .
|
||||
COPY README.md .
|
||||
# Create non-root user with home directory and give ownership of /app
|
||||
RUN groupadd --gid 1000 appuser && \
|
||||
useradd --uid 1000 --gid 1000 --create-home --shell /bin/false appuser && \
|
||||
chown appuser:appuser /app
|
||||
|
||||
# Copy dependency files (owned by appuser via --chown)
|
||||
COPY --chown=appuser:appuser pyproject.toml .
|
||||
COPY --chown=appuser:appuser uv.lock .
|
||||
COPY --chown=appuser:appuser README.md .
|
||||
COPY --chown=appuser:appuser LICENSE .
|
||||
|
||||
# Copy the source code
|
||||
COPY unraid_mcp/ ./unraid_mcp/
|
||||
COPY --chown=appuser:appuser unraid_mcp/ ./unraid_mcp/
|
||||
|
||||
# Switch to non-root user before installing dependencies
|
||||
USER appuser
|
||||
|
||||
# Install dependencies and the package
|
||||
RUN uv sync --frozen
|
||||
@@ -31,5 +40,9 @@ ENV UNRAID_API_KEY=""
|
||||
ENV UNRAID_VERIFY_SSL="true"
|
||||
ENV UNRAID_MCP_LOG_LEVEL="INFO"
|
||||
|
||||
# Run unraid-mcp-server.py when the container launches
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \
|
||||
CMD ["python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:6970/mcp')"]
|
||||
|
||||
# Run unraid-mcp-server when the container launches
|
||||
CMD ["uv", "run", "unraid-mcp-server"]
|
||||
|
||||
@@ -5,6 +5,11 @@ services:
|
||||
dockerfile: Dockerfile
|
||||
container_name: unraid-mcp
|
||||
restart: unless-stopped
|
||||
read_only: true
|
||||
cap_drop:
|
||||
- ALL
|
||||
tmpfs:
|
||||
- /tmp:noexec,nosuid,size=64m
|
||||
ports:
|
||||
# HostPort:ContainerPort (maps to UNRAID_MCP_PORT inside the container, default 6970)
|
||||
# Change the host port (left side) if 6970 is already in use on your host
|
||||
|
||||
@@ -77,7 +77,6 @@ dependencies = [
|
||||
"uvicorn[standard]>=0.35.0",
|
||||
"websockets>=15.0.1",
|
||||
"rich>=14.1.0",
|
||||
"pytz>=2025.2",
|
||||
]
|
||||
|
||||
# ============================================================================
|
||||
@@ -170,6 +169,8 @@ select = [
|
||||
"PERF",
|
||||
# Ruff-specific rules
|
||||
"RUF",
|
||||
# flake8-bandit (security)
|
||||
"S",
|
||||
]
|
||||
ignore = [
|
||||
"E501", # line too long (handled by ruff formatter)
|
||||
@@ -285,7 +286,6 @@ dev = [
|
||||
"pytest-asyncio>=1.2.0",
|
||||
"pytest-cov>=7.0.0",
|
||||
"respx>=0.22.0",
|
||||
"types-pytz>=2025.2.0.20250809",
|
||||
"ty>=0.0.15",
|
||||
"ruff>=0.12.8",
|
||||
"build>=1.2.2",
|
||||
|
||||
@@ -158,43 +158,43 @@ class TestHttpErrorHandling:
|
||||
@respx.mock
|
||||
async def test_http_401_raises_tool_error(self) -> None:
|
||||
respx.post(API_URL).mock(return_value=httpx.Response(401, text="Unauthorized"))
|
||||
with pytest.raises(ToolError, match="HTTP error 401"):
|
||||
with pytest.raises(ToolError, match="Unraid API returned HTTP 401"):
|
||||
await make_graphql_request("query { online }")
|
||||
|
||||
@respx.mock
|
||||
async def test_http_403_raises_tool_error(self) -> None:
|
||||
respx.post(API_URL).mock(return_value=httpx.Response(403, text="Forbidden"))
|
||||
with pytest.raises(ToolError, match="HTTP error 403"):
|
||||
with pytest.raises(ToolError, match="Unraid API returned HTTP 403"):
|
||||
await make_graphql_request("query { online }")
|
||||
|
||||
@respx.mock
|
||||
async def test_http_500_raises_tool_error(self) -> None:
|
||||
respx.post(API_URL).mock(return_value=httpx.Response(500, text="Internal Server Error"))
|
||||
with pytest.raises(ToolError, match="HTTP error 500"):
|
||||
with pytest.raises(ToolError, match="Unraid API returned HTTP 500"):
|
||||
await make_graphql_request("query { online }")
|
||||
|
||||
@respx.mock
|
||||
async def test_http_503_raises_tool_error(self) -> None:
|
||||
respx.post(API_URL).mock(return_value=httpx.Response(503, text="Service Unavailable"))
|
||||
with pytest.raises(ToolError, match="HTTP error 503"):
|
||||
with pytest.raises(ToolError, match="Unraid API returned HTTP 503"):
|
||||
await make_graphql_request("query { online }")
|
||||
|
||||
@respx.mock
|
||||
async def test_network_connection_error(self) -> None:
|
||||
respx.post(API_URL).mock(side_effect=httpx.ConnectError("Connection refused"))
|
||||
with pytest.raises(ToolError, match="Network connection error"):
|
||||
with pytest.raises(ToolError, match="Network error connecting to Unraid API"):
|
||||
await make_graphql_request("query { online }")
|
||||
|
||||
@respx.mock
|
||||
async def test_network_timeout_error(self) -> None:
|
||||
respx.post(API_URL).mock(side_effect=httpx.ReadTimeout("Read timed out"))
|
||||
with pytest.raises(ToolError, match="Network connection error"):
|
||||
with pytest.raises(ToolError, match="Network error connecting to Unraid API"):
|
||||
await make_graphql_request("query { online }")
|
||||
|
||||
@respx.mock
|
||||
async def test_invalid_json_response(self) -> None:
|
||||
respx.post(API_URL).mock(return_value=httpx.Response(200, text="not json"))
|
||||
with pytest.raises(ToolError, match="Invalid JSON response"):
|
||||
with pytest.raises(ToolError, match="invalid response"):
|
||||
await make_graphql_request("query { online }")
|
||||
|
||||
|
||||
@@ -868,14 +868,14 @@ class TestNotificationsToolRequests:
|
||||
title="Test",
|
||||
subject="Sub",
|
||||
description="Desc",
|
||||
importance="info",
|
||||
importance="normal",
|
||||
)
|
||||
body = _extract_request_body(route.calls.last.request)
|
||||
assert "CreateNotification" in body["query"]
|
||||
inp = body["variables"]["input"]
|
||||
assert inp["title"] == "Test"
|
||||
assert inp["subject"] == "Sub"
|
||||
assert inp["importance"] == "INFO" # uppercased
|
||||
assert inp["importance"] == "NORMAL" # uppercased from "normal"
|
||||
|
||||
@respx.mock
|
||||
async def test_archive_sends_id_variable(self) -> None:
|
||||
@@ -1256,7 +1256,7 @@ class TestCrossCuttingConcerns:
|
||||
tool = make_tool_fn(
|
||||
"unraid_mcp.tools.info", "register_info_tool", "unraid_info"
|
||||
)
|
||||
with pytest.raises(ToolError, match="HTTP error 500"):
|
||||
with pytest.raises(ToolError, match="Unraid API returned HTTP 500"):
|
||||
await tool(action="online")
|
||||
|
||||
@respx.mock
|
||||
@@ -1268,7 +1268,7 @@ class TestCrossCuttingConcerns:
|
||||
tool = make_tool_fn(
|
||||
"unraid_mcp.tools.info", "register_info_tool", "unraid_info"
|
||||
)
|
||||
with pytest.raises(ToolError, match="Network connection error"):
|
||||
with pytest.raises(ToolError, match="Network error connecting to Unraid API"):
|
||||
await tool(action="online")
|
||||
|
||||
@respx.mock
|
||||
|
||||
@@ -7,7 +7,7 @@ data management without requiring a live Unraid server.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
@@ -83,7 +83,7 @@ SAMPLE_QUERY = "subscription { test { value } }"
|
||||
|
||||
# Shared patch targets
|
||||
_WS_CONNECT = "unraid_mcp.subscriptions.manager.websockets.connect"
|
||||
_API_URL = "unraid_mcp.subscriptions.manager.UNRAID_API_URL"
|
||||
_API_URL = "unraid_mcp.subscriptions.utils.UNRAID_API_URL"
|
||||
_API_KEY = "unraid_mcp.subscriptions.manager.UNRAID_API_KEY"
|
||||
_SSL_CTX = "unraid_mcp.subscriptions.manager.build_ws_ssl_context"
|
||||
_SLEEP = "unraid_mcp.subscriptions.manager.asyncio.sleep"
|
||||
@@ -100,7 +100,7 @@ class TestSubscriptionManagerInit:
|
||||
mgr = SubscriptionManager()
|
||||
assert mgr.active_subscriptions == {}
|
||||
assert mgr.resource_data == {}
|
||||
assert mgr.websocket is None
|
||||
assert not hasattr(mgr, "websocket")
|
||||
|
||||
def test_default_auto_start_enabled(self) -> None:
|
||||
mgr = SubscriptionManager()
|
||||
@@ -720,20 +720,20 @@ class TestWebSocketURLConstruction:
|
||||
|
||||
class TestResourceData:
|
||||
|
||||
def test_get_resource_data_returns_none_when_empty(self) -> None:
|
||||
async def test_get_resource_data_returns_none_when_empty(self) -> None:
|
||||
mgr = SubscriptionManager()
|
||||
assert mgr.get_resource_data("nonexistent") is None
|
||||
assert await mgr.get_resource_data("nonexistent") is None
|
||||
|
||||
def test_get_resource_data_returns_stored_data(self) -> None:
|
||||
async def test_get_resource_data_returns_stored_data(self) -> None:
|
||||
from unraid_mcp.core.types import SubscriptionData
|
||||
|
||||
mgr = SubscriptionManager()
|
||||
mgr.resource_data["test"] = SubscriptionData(
|
||||
data={"key": "value"},
|
||||
last_updated=datetime.now(),
|
||||
last_updated=datetime.now(UTC),
|
||||
subscription_type="test",
|
||||
)
|
||||
result = mgr.get_resource_data("test")
|
||||
result = await mgr.get_resource_data("test")
|
||||
assert result == {"key": "value"}
|
||||
|
||||
def test_list_active_subscriptions_empty(self) -> None:
|
||||
@@ -755,46 +755,46 @@ class TestResourceData:
|
||||
|
||||
class TestSubscriptionStatus:
|
||||
|
||||
def test_status_includes_all_configured_subscriptions(self) -> None:
|
||||
async def test_status_includes_all_configured_subscriptions(self) -> None:
|
||||
mgr = SubscriptionManager()
|
||||
status = mgr.get_subscription_status()
|
||||
status = await mgr.get_subscription_status()
|
||||
for name in mgr.subscription_configs:
|
||||
assert name in status
|
||||
|
||||
def test_status_default_connection_state(self) -> None:
|
||||
async def test_status_default_connection_state(self) -> None:
|
||||
mgr = SubscriptionManager()
|
||||
status = mgr.get_subscription_status()
|
||||
status = await mgr.get_subscription_status()
|
||||
for sub_status in status.values():
|
||||
assert sub_status["runtime"]["connection_state"] == "not_started"
|
||||
|
||||
def test_status_shows_active_flag(self) -> None:
|
||||
async def test_status_shows_active_flag(self) -> None:
|
||||
mgr = SubscriptionManager()
|
||||
mgr.active_subscriptions["logFileSubscription"] = MagicMock()
|
||||
status = mgr.get_subscription_status()
|
||||
status = await mgr.get_subscription_status()
|
||||
assert status["logFileSubscription"]["runtime"]["active"] is True
|
||||
|
||||
def test_status_shows_data_availability(self) -> None:
|
||||
async def test_status_shows_data_availability(self) -> None:
|
||||
from unraid_mcp.core.types import SubscriptionData
|
||||
|
||||
mgr = SubscriptionManager()
|
||||
mgr.resource_data["logFileSubscription"] = SubscriptionData(
|
||||
data={"log": "content"},
|
||||
last_updated=datetime.now(),
|
||||
last_updated=datetime.now(UTC),
|
||||
subscription_type="logFileSubscription",
|
||||
)
|
||||
status = mgr.get_subscription_status()
|
||||
status = await mgr.get_subscription_status()
|
||||
assert status["logFileSubscription"]["data"]["available"] is True
|
||||
|
||||
def test_status_shows_error_info(self) -> None:
|
||||
async def test_status_shows_error_info(self) -> None:
|
||||
mgr = SubscriptionManager()
|
||||
mgr.last_error["logFileSubscription"] = "Test error message"
|
||||
status = mgr.get_subscription_status()
|
||||
status = await mgr.get_subscription_status()
|
||||
assert status["logFileSubscription"]["runtime"]["last_error"] == "Test error message"
|
||||
|
||||
def test_status_reconnect_attempts_tracked(self) -> None:
|
||||
async def test_status_reconnect_attempts_tracked(self) -> None:
|
||||
mgr = SubscriptionManager()
|
||||
mgr.reconnect_attempts["logFileSubscription"] = 3
|
||||
status = mgr.get_subscription_status()
|
||||
status = await mgr.get_subscription_status()
|
||||
assert status["logFileSubscription"]["runtime"]["reconnect_attempts"] == 3
|
||||
|
||||
|
||||
|
||||
@@ -384,10 +384,16 @@ class TestVmQueries:
|
||||
errors = _validate_operation(schema, QUERIES["list"])
|
||||
assert not errors, f"list query validation failed: {errors}"
|
||||
|
||||
def test_details_query(self, schema: GraphQLSchema) -> None:
|
||||
from unraid_mcp.tools.virtualization import QUERIES
|
||||
|
||||
errors = _validate_operation(schema, QUERIES["details"])
|
||||
assert not errors, f"details query validation failed: {errors}"
|
||||
|
||||
def test_all_vm_queries_covered(self, schema: GraphQLSchema) -> None:
|
||||
from unraid_mcp.tools.virtualization import QUERIES
|
||||
|
||||
assert set(QUERIES.keys()) == {"list"}
|
||||
assert set(QUERIES.keys()) == {"list", "details"}
|
||||
|
||||
|
||||
class TestVmMutations:
|
||||
|
||||
@@ -274,7 +274,7 @@ class TestMakeGraphQLRequestErrors:
|
||||
|
||||
with (
|
||||
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
|
||||
pytest.raises(ToolError, match="HTTP error 401"),
|
||||
pytest.raises(ToolError, match="Unraid API returned HTTP 401"),
|
||||
):
|
||||
await make_graphql_request("{ info }")
|
||||
|
||||
@@ -292,7 +292,7 @@ class TestMakeGraphQLRequestErrors:
|
||||
|
||||
with (
|
||||
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
|
||||
pytest.raises(ToolError, match="HTTP error 500"),
|
||||
pytest.raises(ToolError, match="Unraid API returned HTTP 500"),
|
||||
):
|
||||
await make_graphql_request("{ info }")
|
||||
|
||||
@@ -310,7 +310,7 @@ class TestMakeGraphQLRequestErrors:
|
||||
|
||||
with (
|
||||
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
|
||||
pytest.raises(ToolError, match="HTTP error 503"),
|
||||
pytest.raises(ToolError, match="Unraid API returned HTTP 503"),
|
||||
):
|
||||
await make_graphql_request("{ info }")
|
||||
|
||||
@@ -320,7 +320,7 @@ class TestMakeGraphQLRequestErrors:
|
||||
|
||||
with (
|
||||
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
|
||||
pytest.raises(ToolError, match="Network connection error"),
|
||||
pytest.raises(ToolError, match="Network error connecting to Unraid API"),
|
||||
):
|
||||
await make_graphql_request("{ info }")
|
||||
|
||||
@@ -330,7 +330,7 @@ class TestMakeGraphQLRequestErrors:
|
||||
|
||||
with (
|
||||
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
|
||||
pytest.raises(ToolError, match="Network connection error"),
|
||||
pytest.raises(ToolError, match="Network error connecting to Unraid API"),
|
||||
):
|
||||
await make_graphql_request("{ info }")
|
||||
|
||||
@@ -344,7 +344,7 @@ class TestMakeGraphQLRequestErrors:
|
||||
|
||||
with (
|
||||
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
|
||||
pytest.raises(ToolError, match="Invalid JSON response"),
|
||||
pytest.raises(ToolError, match="invalid response.*not valid JSON"),
|
||||
):
|
||||
await make_graphql_request("{ info }")
|
||||
|
||||
|
||||
@@ -92,7 +92,7 @@ class TestNotificationsActions:
|
||||
title="Test",
|
||||
subject="Test Subject",
|
||||
description="Test Desc",
|
||||
importance="info",
|
||||
importance="normal",
|
||||
)
|
||||
assert result["success"] is True
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import pytest
|
||||
from conftest import make_tool_fn
|
||||
|
||||
from unraid_mcp.core.exceptions import ToolError
|
||||
from unraid_mcp.tools.storage import format_bytes
|
||||
from unraid_mcp.core.utils import format_bytes
|
||||
|
||||
|
||||
# --- Unit tests for helpers ---
|
||||
|
||||
@@ -4,4 +4,10 @@ A modular MCP (Model Context Protocol) server that provides tools to interact
|
||||
with an Unraid server's GraphQL API.
|
||||
"""
|
||||
|
||||
__version__ = "0.2.0"
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
|
||||
|
||||
try:
|
||||
__version__ = version("unraid-mcp")
|
||||
except PackageNotFoundError:
|
||||
__version__ = "0.0.0"
|
||||
|
||||
@@ -5,16 +5,10 @@ that cap at 10MB and start over (no rotation) for consistent use across all modu
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import pytz
|
||||
from rich.align import Align
|
||||
from rich.console import Console
|
||||
from rich.logging import RichHandler
|
||||
from rich.panel import Panel
|
||||
from rich.rule import Rule
|
||||
from rich.text import Text
|
||||
|
||||
|
||||
try:
|
||||
@@ -28,7 +22,7 @@ from .settings import LOG_FILE_PATH, LOG_LEVEL_STR
|
||||
|
||||
|
||||
# Global Rich console for consistent formatting
|
||||
console = Console(stderr=True, force_terminal=True)
|
||||
console = Console(stderr=True)
|
||||
|
||||
|
||||
class OverwriteFileHandler(logging.FileHandler):
|
||||
@@ -45,12 +39,18 @@ class OverwriteFileHandler(logging.FileHandler):
|
||||
delay: Whether to delay file opening
|
||||
"""
|
||||
self.max_bytes = max_bytes
|
||||
self._emit_count = 0
|
||||
self._check_interval = 100
|
||||
super().__init__(filename, mode, encoding, delay)
|
||||
|
||||
def emit(self, record):
|
||||
"""Emit a record, checking file size and overwriting if needed."""
|
||||
# Check file size before writing
|
||||
if self.stream and hasattr(self.stream, "name"):
|
||||
"""Emit a record, checking file size periodically and overwriting if needed."""
|
||||
self._emit_count += 1
|
||||
if (
|
||||
self._emit_count % self._check_interval == 0
|
||||
and self.stream
|
||||
and hasattr(self.stream, "name")
|
||||
):
|
||||
try:
|
||||
base_path = Path(self.baseFilename)
|
||||
if base_path.exists():
|
||||
@@ -91,6 +91,28 @@ class OverwriteFileHandler(logging.FileHandler):
|
||||
super().emit(record)
|
||||
|
||||
|
||||
def _create_shared_file_handler() -> OverwriteFileHandler:
|
||||
"""Create the single shared file handler for all loggers.
|
||||
|
||||
Returns:
|
||||
Configured OverwriteFileHandler instance
|
||||
"""
|
||||
numeric_log_level = getattr(logging, LOG_LEVEL_STR, logging.INFO)
|
||||
handler = OverwriteFileHandler(LOG_FILE_PATH, max_bytes=10 * 1024 * 1024, encoding="utf-8")
|
||||
handler.setLevel(numeric_log_level)
|
||||
handler.setFormatter(
|
||||
logging.Formatter(
|
||||
"%(asctime)s - %(name)s - %(levelname)s - %(module)s - %(funcName)s - %(lineno)d - %(message)s"
|
||||
)
|
||||
)
|
||||
return handler
|
||||
|
||||
|
||||
# Single shared file handler — all loggers reuse this instance to avoid
|
||||
# race conditions from multiple OverwriteFileHandler instances on the same file.
|
||||
_shared_file_handler = _create_shared_file_handler()
|
||||
|
||||
|
||||
def setup_logger(name: str = "UnraidMCPServer") -> logging.Logger:
|
||||
"""Set up and configure the logger with console and file handlers.
|
||||
|
||||
@@ -118,19 +140,13 @@ def setup_logger(name: str = "UnraidMCPServer") -> logging.Logger:
|
||||
show_level=True,
|
||||
show_path=False,
|
||||
rich_tracebacks=True,
|
||||
tracebacks_show_locals=True,
|
||||
tracebacks_show_locals=False,
|
||||
)
|
||||
console_handler.setLevel(numeric_log_level)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
# File Handler with 10MB cap (overwrites instead of rotating)
|
||||
file_handler = OverwriteFileHandler(LOG_FILE_PATH, max_bytes=10 * 1024 * 1024, encoding="utf-8")
|
||||
file_handler.setLevel(numeric_log_level)
|
||||
file_formatter = logging.Formatter(
|
||||
"%(asctime)s - %(name)s - %(levelname)s - %(module)s - %(funcName)s - %(lineno)d - %(message)s"
|
||||
)
|
||||
file_handler.setFormatter(file_formatter)
|
||||
logger.addHandler(file_handler)
|
||||
# Reuse the shared file handler
|
||||
logger.addHandler(_shared_file_handler)
|
||||
|
||||
return logger
|
||||
|
||||
@@ -157,20 +173,14 @@ def configure_fastmcp_logger_with_rich() -> logging.Logger | None:
|
||||
show_level=True,
|
||||
show_path=False,
|
||||
rich_tracebacks=True,
|
||||
tracebacks_show_locals=True,
|
||||
tracebacks_show_locals=False,
|
||||
markup=True,
|
||||
)
|
||||
console_handler.setLevel(numeric_log_level)
|
||||
fastmcp_logger.addHandler(console_handler)
|
||||
|
||||
# File Handler with 10MB cap (overwrites instead of rotating)
|
||||
file_handler = OverwriteFileHandler(LOG_FILE_PATH, max_bytes=10 * 1024 * 1024, encoding="utf-8")
|
||||
file_handler.setLevel(numeric_log_level)
|
||||
file_formatter = logging.Formatter(
|
||||
"%(asctime)s - %(name)s - %(levelname)s - %(module)s - %(funcName)s - %(lineno)d - %(message)s"
|
||||
)
|
||||
file_handler.setFormatter(file_formatter)
|
||||
fastmcp_logger.addHandler(file_handler)
|
||||
# Reuse the shared file handler
|
||||
fastmcp_logger.addHandler(_shared_file_handler)
|
||||
|
||||
fastmcp_logger.setLevel(numeric_log_level)
|
||||
|
||||
@@ -186,30 +196,19 @@ def configure_fastmcp_logger_with_rich() -> logging.Logger | None:
|
||||
show_level=True,
|
||||
show_path=False,
|
||||
rich_tracebacks=True,
|
||||
tracebacks_show_locals=True,
|
||||
tracebacks_show_locals=False,
|
||||
markup=True,
|
||||
)
|
||||
root_console_handler.setLevel(numeric_log_level)
|
||||
root_logger.addHandler(root_console_handler)
|
||||
|
||||
# File Handler for root logger with 10MB cap (overwrites instead of rotating)
|
||||
root_file_handler = OverwriteFileHandler(
|
||||
LOG_FILE_PATH, max_bytes=10 * 1024 * 1024, encoding="utf-8"
|
||||
)
|
||||
root_file_handler.setLevel(numeric_log_level)
|
||||
root_file_handler.setFormatter(file_formatter)
|
||||
root_logger.addHandler(root_file_handler)
|
||||
# Reuse the shared file handler for root logger
|
||||
root_logger.addHandler(_shared_file_handler)
|
||||
root_logger.setLevel(numeric_log_level)
|
||||
|
||||
return fastmcp_logger
|
||||
|
||||
|
||||
def setup_uvicorn_logging() -> logging.Logger | None:
|
||||
"""Configure uvicorn and other third-party loggers to use Rich formatting."""
|
||||
# This function is kept for backward compatibility but now delegates to FastMCP
|
||||
return configure_fastmcp_logger_with_rich()
|
||||
|
||||
|
||||
def log_configuration_status(logger: logging.Logger) -> None:
|
||||
"""Log configuration status at startup.
|
||||
|
||||
@@ -242,97 +241,6 @@ def log_configuration_status(logger: logging.Logger) -> None:
|
||||
logger.error(f"Missing required configuration: {config['missing_config']}")
|
||||
|
||||
|
||||
# Development logging helpers for Rich formatting
|
||||
def get_est_timestamp() -> str:
|
||||
"""Get current timestamp in EST timezone with YY/MM/DD format."""
|
||||
est = pytz.timezone("US/Eastern")
|
||||
now = datetime.now(est)
|
||||
return now.strftime("%y/%m/%d %H:%M:%S")
|
||||
|
||||
|
||||
def log_header(title: str) -> None:
|
||||
"""Print a beautiful header panel with Nordic blue styling."""
|
||||
panel = Panel(
|
||||
Align.center(Text(title, style="bold white")),
|
||||
style="#5E81AC", # Nordic blue
|
||||
padding=(0, 2),
|
||||
border_style="#81A1C1", # Light Nordic blue
|
||||
)
|
||||
console.print(panel)
|
||||
|
||||
|
||||
def log_with_level_and_indent(message: str, level: str = "info", indent: int = 0) -> None:
|
||||
"""Log a message with specific level and indentation."""
|
||||
timestamp = get_est_timestamp()
|
||||
indent_str = " " * indent
|
||||
|
||||
# Enhanced Nordic color scheme with more blues
|
||||
level_config = {
|
||||
"error": {"color": "#BF616A", "icon": "❌", "style": "bold"}, # Nordic red
|
||||
"warning": {"color": "#EBCB8B", "icon": "⚠️", "style": ""}, # Nordic yellow
|
||||
"success": {"color": "#A3BE8C", "icon": "✅", "style": "bold"}, # Nordic green
|
||||
"info": {"color": "#5E81AC", "icon": "\u2139\ufe0f", "style": "bold"}, # Nordic blue (bold)
|
||||
"status": {"color": "#81A1C1", "icon": "🔍", "style": ""}, # Light Nordic blue
|
||||
"debug": {"color": "#4C566A", "icon": "🐛", "style": ""}, # Nordic dark gray
|
||||
}
|
||||
|
||||
config = level_config.get(
|
||||
level, {"color": "#81A1C1", "icon": "•", "style": ""}
|
||||
) # Default to light Nordic blue
|
||||
|
||||
# Create beautifully formatted text
|
||||
text = Text()
|
||||
|
||||
# Timestamp with Nordic blue styling
|
||||
text.append(f"[{timestamp}]", style="#81A1C1") # Light Nordic blue for timestamps
|
||||
text.append(" ")
|
||||
|
||||
# Indentation with Nordic blue styling
|
||||
if indent > 0:
|
||||
text.append(indent_str, style="#81A1C1")
|
||||
|
||||
# Level icon (only for certain levels)
|
||||
if level in ["error", "warning", "success"]:
|
||||
# Extract emoji from message if it starts with one, to avoid duplication
|
||||
if message and len(message) > 0 and ord(message[0]) >= 0x1F600: # Emoji range
|
||||
# Message already has emoji, don't add icon
|
||||
pass
|
||||
else:
|
||||
text.append(f"{config['icon']} ", style=config["color"])
|
||||
|
||||
# Message content
|
||||
message_style = f"{config['color']} {config['style']}".strip()
|
||||
text.append(message, style=message_style)
|
||||
|
||||
console.print(text)
|
||||
|
||||
|
||||
def log_separator() -> None:
|
||||
"""Print a beautiful separator line with Nordic blue styling."""
|
||||
console.print(Rule(style="#81A1C1"))
|
||||
|
||||
|
||||
# Convenience functions for different log levels
|
||||
def log_error(message: str, indent: int = 0) -> None:
|
||||
log_with_level_and_indent(message, "error", indent)
|
||||
|
||||
|
||||
def log_warning(message: str, indent: int = 0) -> None:
|
||||
log_with_level_and_indent(message, "warning", indent)
|
||||
|
||||
|
||||
def log_success(message: str, indent: int = 0) -> None:
|
||||
log_with_level_and_indent(message, "success", indent)
|
||||
|
||||
|
||||
def log_info(message: str, indent: int = 0) -> None:
|
||||
log_with_level_and_indent(message, "info", indent)
|
||||
|
||||
|
||||
def log_status(message: str, indent: int = 0) -> None:
|
||||
log_with_level_and_indent(message, "status", indent)
|
||||
|
||||
|
||||
# Global logger instance - modules can import this directly
|
||||
if FASTMCP_AVAILABLE:
|
||||
# Use FastMCP logger with Rich formatting
|
||||
@@ -341,5 +249,5 @@ if FASTMCP_AVAILABLE:
|
||||
else:
|
||||
# Fallback to our custom logger if FastMCP is not available
|
||||
logger = setup_logger()
|
||||
# Setup uvicorn logging when module is imported
|
||||
setup_uvicorn_logging()
|
||||
# Also configure FastMCP logger for consistency
|
||||
configure_fastmcp_logger_with_rich()
|
||||
|
||||
@@ -5,6 +5,7 @@ and provides all configuration constants used throughout the application.
|
||||
"""
|
||||
|
||||
import os
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@@ -30,8 +31,11 @@ for dotenv_path in dotenv_paths:
|
||||
load_dotenv(dotenv_path=dotenv_path)
|
||||
break
|
||||
|
||||
# Application Version
|
||||
VERSION = "0.2.0"
|
||||
# Application Version (single source of truth: pyproject.toml)
|
||||
try:
|
||||
VERSION = version("unraid-mcp")
|
||||
except PackageNotFoundError:
|
||||
VERSION = "0.0.0"
|
||||
|
||||
# Core API Configuration
|
||||
UNRAID_API_URL = os.getenv("UNRAID_API_URL")
|
||||
@@ -39,7 +43,7 @@ UNRAID_API_KEY = os.getenv("UNRAID_API_KEY")
|
||||
|
||||
# Server Configuration
|
||||
UNRAID_MCP_PORT = int(os.getenv("UNRAID_MCP_PORT", "6970"))
|
||||
UNRAID_MCP_HOST = os.getenv("UNRAID_MCP_HOST", "0.0.0.0")
|
||||
UNRAID_MCP_HOST = os.getenv("UNRAID_MCP_HOST", "0.0.0.0") # noqa: S104 — intentional for Docker
|
||||
UNRAID_MCP_TRANSPORT = os.getenv("UNRAID_MCP_TRANSPORT", "streamable-http").lower()
|
||||
|
||||
# SSL Configuration
|
||||
@@ -54,7 +58,8 @@ else: # Path to CA bundle
|
||||
# Logging Configuration
|
||||
LOG_LEVEL_STR = os.getenv("UNRAID_MCP_LOG_LEVEL", "INFO").upper()
|
||||
LOG_FILE_NAME = os.getenv("UNRAID_MCP_LOG_FILE", "unraid-mcp.log")
|
||||
LOGS_DIR = Path("/tmp")
|
||||
# Use /app/logs in Docker, project-relative logs/ directory otherwise
|
||||
LOGS_DIR = Path("/app/logs") if Path("/app").is_dir() else PROJECT_ROOT / "logs"
|
||||
LOG_FILE_PATH = LOGS_DIR / LOG_FILE_NAME
|
||||
|
||||
# Ensure logs directory exists
|
||||
|
||||
@@ -5,7 +5,9 @@ to the Unraid API with proper timeout handling and error management.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
@@ -22,7 +24,19 @@ from ..core.exceptions import ToolError
|
||||
|
||||
|
||||
# Sensitive keys to redact from debug logs
|
||||
_SENSITIVE_KEYS = {"password", "key", "secret", "token", "apikey"}
|
||||
_SENSITIVE_KEYS = {
|
||||
"password",
|
||||
"key",
|
||||
"secret",
|
||||
"token",
|
||||
"apikey",
|
||||
"authorization",
|
||||
"cookie",
|
||||
"session",
|
||||
"credential",
|
||||
"passphrase",
|
||||
"jwt",
|
||||
}
|
||||
|
||||
|
||||
def _is_sensitive_key(key: str) -> bool:
|
||||
@@ -67,7 +81,121 @@ def get_timeout_for_operation(profile: str) -> httpx.Timeout:
|
||||
|
||||
# Global connection pool (module-level singleton)
|
||||
_http_client: httpx.AsyncClient | None = None
|
||||
_client_lock: asyncio.Lock | None = None
|
||||
|
||||
|
||||
def _get_client_lock() -> asyncio.Lock:
|
||||
"""Get or create the client lock (lazy init to avoid event loop issues)."""
|
||||
global _client_lock
|
||||
if _client_lock is None:
|
||||
_client_lock = asyncio.Lock()
|
||||
return _client_lock
|
||||
|
||||
|
||||
class _RateLimiter:
|
||||
"""Token bucket rate limiter for Unraid API (100 req / 10s hard limit).
|
||||
|
||||
Uses 90 tokens with 9.0 tokens/sec refill for 10% safety headroom.
|
||||
"""
|
||||
|
||||
def __init__(self, max_tokens: int = 90, refill_rate: float = 9.0) -> None:
|
||||
self.max_tokens = max_tokens
|
||||
self.tokens = float(max_tokens)
|
||||
self.refill_rate = refill_rate # tokens per second
|
||||
self.last_refill = time.monotonic()
|
||||
self._lock: asyncio.Lock | None = None
|
||||
|
||||
def _get_lock(self) -> asyncio.Lock:
|
||||
if self._lock is None:
|
||||
self._lock = asyncio.Lock()
|
||||
return self._lock
|
||||
|
||||
def _refill(self) -> None:
|
||||
"""Refill tokens based on elapsed time."""
|
||||
now = time.monotonic()
|
||||
elapsed = now - self.last_refill
|
||||
self.tokens = min(self.max_tokens, self.tokens + elapsed * self.refill_rate)
|
||||
self.last_refill = now
|
||||
|
||||
async def acquire(self) -> None:
|
||||
"""Consume one token, waiting if necessary for refill."""
|
||||
while True:
|
||||
async with self._get_lock():
|
||||
self._refill()
|
||||
if self.tokens >= 1:
|
||||
self.tokens -= 1
|
||||
return
|
||||
wait_time = (1 - self.tokens) / self.refill_rate
|
||||
|
||||
# Sleep outside the lock so other coroutines aren't blocked
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
|
||||
_rate_limiter = _RateLimiter()
|
||||
|
||||
|
||||
# --- TTL Cache for stable read-only queries ---
|
||||
|
||||
# Queries whose results change infrequently and are safe to cache.
|
||||
# Mutations and volatile queries (metrics, docker, array state) are excluded.
|
||||
_CACHEABLE_QUERY_PREFIXES = frozenset(
|
||||
{
|
||||
"GetNetworkConfig",
|
||||
"GetRegistrationInfo",
|
||||
"GetOwner",
|
||||
"GetFlash",
|
||||
}
|
||||
)
|
||||
|
||||
_CACHE_TTL_SECONDS = 60.0
|
||||
|
||||
|
||||
class _QueryCache:
|
||||
"""Simple TTL cache for GraphQL query responses.
|
||||
|
||||
Keyed by a hash of (query, variables). Entries expire after _CACHE_TTL_SECONDS.
|
||||
Only caches responses for queries whose operation name is in _CACHEABLE_QUERY_PREFIXES.
|
||||
Mutation requests always bypass the cache.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._store: dict[str, tuple[float, dict[str, Any]]] = {}
|
||||
|
||||
@staticmethod
|
||||
def _cache_key(query: str, variables: dict[str, Any] | None) -> str:
|
||||
raw = query + json.dumps(variables or {}, sort_keys=True)
|
||||
return hashlib.sha256(raw.encode()).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def is_cacheable(query: str) -> bool:
|
||||
"""Check if a query is eligible for caching based on its operation name."""
|
||||
if query.lstrip().startswith("mutation"):
|
||||
return False
|
||||
return any(prefix in query for prefix in _CACHEABLE_QUERY_PREFIXES)
|
||||
|
||||
def get(self, query: str, variables: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||
"""Return cached result if present and not expired, else None."""
|
||||
key = self._cache_key(query, variables)
|
||||
entry = self._store.get(key)
|
||||
if entry is None:
|
||||
return None
|
||||
expires_at, data = entry
|
||||
if time.monotonic() > expires_at:
|
||||
del self._store[key]
|
||||
return None
|
||||
return data
|
||||
|
||||
def put(self, query: str, variables: dict[str, Any] | None, data: dict[str, Any]) -> None:
|
||||
"""Store a query result with TTL expiry."""
|
||||
key = self._cache_key(query, variables)
|
||||
self._store[key] = (time.monotonic() + _CACHE_TTL_SECONDS, data)
|
||||
|
||||
def invalidate_all(self) -> None:
|
||||
"""Clear the entire cache (called after mutations)."""
|
||||
self._store.clear()
|
||||
|
||||
|
||||
_query_cache = _QueryCache()
|
||||
|
||||
|
||||
def is_idempotent_error(error_message: str, operation: str) -> bool:
|
||||
@@ -109,7 +237,7 @@ async def _create_http_client() -> httpx.AsyncClient:
|
||||
return httpx.AsyncClient(
|
||||
# Connection pool settings
|
||||
limits=httpx.Limits(
|
||||
max_keepalive_connections=20, max_connections=100, keepalive_expiry=30.0
|
||||
max_keepalive_connections=20, max_connections=20, keepalive_expiry=30.0
|
||||
),
|
||||
# Default timeout (can be overridden per-request)
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
@@ -123,40 +251,35 @@ async def _create_http_client() -> httpx.AsyncClient:
|
||||
async def get_http_client() -> httpx.AsyncClient:
|
||||
"""Get or create shared HTTP client with connection pooling.
|
||||
|
||||
The client is protected by an asyncio lock to prevent concurrent creation.
|
||||
If the existing client was closed (e.g., during shutdown), a new one is created.
|
||||
Uses double-checked locking: fast-path skips the lock when the client
|
||||
is already initialized, only acquiring it for initial creation or
|
||||
recovery after close.
|
||||
|
||||
Returns:
|
||||
Singleton AsyncClient instance with connection pooling enabled
|
||||
"""
|
||||
global _http_client
|
||||
|
||||
async with _client_lock:
|
||||
# Fast-path: skip lock if client is already initialized and open
|
||||
client = _http_client
|
||||
if client is not None and not client.is_closed:
|
||||
return client
|
||||
|
||||
# Slow-path: acquire lock for initialization
|
||||
async with _get_client_lock():
|
||||
if _http_client is None or _http_client.is_closed:
|
||||
_http_client = await _create_http_client()
|
||||
logger.info(
|
||||
"Created shared HTTP client with connection pooling (20 keepalive, 100 max connections)"
|
||||
"Created shared HTTP client with connection pooling (20 keepalive, 20 max connections)"
|
||||
)
|
||||
|
||||
client = _http_client
|
||||
|
||||
# Verify client is still open after releasing the lock.
|
||||
# In asyncio's cooperative model this is unlikely to fail, but guards
|
||||
# against edge cases where close_http_client runs between yield points.
|
||||
if client.is_closed:
|
||||
async with _client_lock:
|
||||
_http_client = await _create_http_client()
|
||||
client = _http_client
|
||||
logger.info("Re-created HTTP client after unexpected close")
|
||||
|
||||
return client
|
||||
return _http_client
|
||||
|
||||
|
||||
async def close_http_client() -> None:
|
||||
"""Close the shared HTTP client (call on server shutdown)."""
|
||||
global _http_client
|
||||
|
||||
async with _client_lock:
|
||||
async with _get_client_lock():
|
||||
if _http_client is not None:
|
||||
await _http_client.aclose()
|
||||
_http_client = None
|
||||
@@ -190,6 +313,14 @@ async def make_graphql_request(
|
||||
if not UNRAID_API_KEY:
|
||||
raise ToolError("UNRAID_API_KEY not configured")
|
||||
|
||||
# Check TTL cache for stable read-only queries
|
||||
is_mutation = query.lstrip().startswith("mutation")
|
||||
if not is_mutation and _query_cache.is_cacheable(query):
|
||||
cached = _query_cache.get(query, variables)
|
||||
if cached is not None:
|
||||
logger.debug("Returning cached response for query")
|
||||
return cached
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"X-API-Key": UNRAID_API_KEY,
|
||||
@@ -205,17 +336,31 @@ async def make_graphql_request(
|
||||
logger.debug(f"Variables: {_redact_sensitive(variables)}")
|
||||
|
||||
try:
|
||||
# Rate limit: consume a token before making the request
|
||||
await _rate_limiter.acquire()
|
||||
|
||||
# Get the shared HTTP client with connection pooling
|
||||
client = await get_http_client()
|
||||
|
||||
# Override timeout if custom timeout specified
|
||||
# Retry loop for 429 rate limit responses
|
||||
post_kwargs: dict[str, Any] = {"json": payload, "headers": headers}
|
||||
if custom_timeout is not None:
|
||||
response = await client.post(
|
||||
UNRAID_API_URL, json=payload, headers=headers, timeout=custom_timeout
|
||||
)
|
||||
else:
|
||||
response = await client.post(UNRAID_API_URL, json=payload, headers=headers)
|
||||
post_kwargs["timeout"] = custom_timeout
|
||||
|
||||
response: httpx.Response | None = None
|
||||
for attempt in range(3):
|
||||
response = await client.post(UNRAID_API_URL, **post_kwargs)
|
||||
if response.status_code == 429:
|
||||
backoff = 2**attempt
|
||||
logger.warning(
|
||||
f"Rate limited (429) by Unraid API, retrying in {backoff}s (attempt {attempt + 1}/3)"
|
||||
)
|
||||
await asyncio.sleep(backoff)
|
||||
continue
|
||||
break
|
||||
|
||||
if response is None: # pragma: no cover — guaranteed by loop
|
||||
raise ToolError("No response received after retry attempts")
|
||||
response.raise_for_status() # Raise an exception for HTTP error codes 4xx/5xx
|
||||
|
||||
response_data = response.json()
|
||||
@@ -245,14 +390,27 @@ async def make_graphql_request(
|
||||
|
||||
logger.debug("GraphQL request successful.")
|
||||
data = response_data.get("data", {})
|
||||
return data if isinstance(data, dict) else {} # Ensure we return dict
|
||||
result = data if isinstance(data, dict) else {} # Ensure we return dict
|
||||
|
||||
# Invalidate cache on mutations; cache eligible query results
|
||||
if is_mutation:
|
||||
_query_cache.invalidate_all()
|
||||
elif _query_cache.is_cacheable(query):
|
||||
_query_cache.put(query, variables, result)
|
||||
|
||||
return result
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
# Log full details internally; only expose status code to MCP client
|
||||
logger.error(f"HTTP error occurred: {e.response.status_code} - {e.response.text}")
|
||||
raise ToolError(f"HTTP error {e.response.status_code}: {e.response.text}") from e
|
||||
raise ToolError(
|
||||
f"Unraid API returned HTTP {e.response.status_code}. Check server logs for details."
|
||||
) from e
|
||||
except httpx.RequestError as e:
|
||||
# Log full error internally; give safe summary to MCP client
|
||||
logger.error(f"Request error occurred: {e}")
|
||||
raise ToolError(f"Network connection error: {e!s}") from e
|
||||
raise ToolError(f"Network error connecting to Unraid API: {type(e).__name__}") from e
|
||||
except json.JSONDecodeError as e:
|
||||
# Log full decode error; give safe summary to MCP client
|
||||
logger.error(f"Failed to decode JSON response: {e}")
|
||||
raise ToolError(f"Invalid JSON response from Unraid API: {e!s}") from e
|
||||
raise ToolError("Unraid API returned an invalid response (not valid JSON)") from e
|
||||
|
||||
@@ -4,6 +4,10 @@ This module defines custom exception classes for consistent error handling
|
||||
throughout the application, with proper integration to FastMCP's error system.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
|
||||
from fastmcp.exceptions import ToolError as FastMCPToolError
|
||||
|
||||
|
||||
@@ -19,36 +23,26 @@ class ToolError(FastMCPToolError):
|
||||
pass
|
||||
|
||||
|
||||
class ConfigurationError(ToolError):
|
||||
"""Raised when there are configuration-related errors."""
|
||||
@contextlib.contextmanager
|
||||
def tool_error_handler(
|
||||
tool_name: str,
|
||||
action: str,
|
||||
logger: logging.Logger,
|
||||
) -> Generator[None]:
|
||||
"""Context manager that standardizes tool error handling.
|
||||
|
||||
pass
|
||||
Re-raises ToolError as-is. Catches all other exceptions, logs them
|
||||
with full traceback, and wraps them in ToolError with a descriptive message.
|
||||
|
||||
|
||||
class UnraidAPIError(ToolError):
|
||||
"""Raised when the Unraid API returns an error or is unreachable."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SubscriptionError(ToolError):
|
||||
"""Raised when there are WebSocket subscription-related errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ValidationError(ToolError):
|
||||
"""Raised when input validation fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class IdempotentOperationError(ToolError):
|
||||
"""Raised when an operation is idempotent (already in desired state).
|
||||
|
||||
This is used internally to signal that an operation was already complete,
|
||||
which should typically be converted to a success response rather than
|
||||
propagated as an error to the user.
|
||||
Args:
|
||||
tool_name: The tool name for error messages (e.g., "docker", "vm").
|
||||
action: The current action being executed.
|
||||
logger: The logger instance to use for error logging.
|
||||
"""
|
||||
|
||||
pass
|
||||
try:
|
||||
yield
|
||||
except ToolError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in unraid_{tool_name} action={action}: {e}", exc_info=True)
|
||||
raise ToolError(f"Failed to execute {tool_name}/{action}: {e!s}") from e
|
||||
|
||||
@@ -9,27 +9,33 @@ from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(slots=True)
|
||||
class SubscriptionData:
|
||||
"""Container for subscription data with metadata."""
|
||||
"""Container for subscription data with metadata.
|
||||
|
||||
Note: last_updated must be timezone-aware (use datetime.now(UTC)).
|
||||
"""
|
||||
|
||||
data: dict[str, Any]
|
||||
last_updated: datetime
|
||||
last_updated: datetime # Must be timezone-aware (UTC)
|
||||
subscription_type: str
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(slots=True)
|
||||
class SystemHealth:
|
||||
"""Container for system health status information."""
|
||||
"""Container for system health status information.
|
||||
|
||||
Note: last_checked must be timezone-aware (use datetime.now(UTC)).
|
||||
"""
|
||||
|
||||
is_healthy: bool
|
||||
issues: list[str]
|
||||
warnings: list[str]
|
||||
last_checked: datetime
|
||||
last_checked: datetime # Must be timezone-aware (UTC)
|
||||
component_status: dict[str, str]
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(slots=True)
|
||||
class APIResponse:
|
||||
"""Container for standardized API response data."""
|
||||
|
||||
|
||||
68
unraid_mcp/core/utils.py
Normal file
68
unraid_mcp/core/utils.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Shared utility functions for Unraid MCP tools."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def safe_get(data: dict[str, Any], *keys: str, default: Any = None) -> Any:
|
||||
"""Safely traverse nested dict keys, handling None intermediates.
|
||||
|
||||
Args:
|
||||
data: The root dictionary to traverse.
|
||||
*keys: Sequence of keys to follow.
|
||||
default: Value to return if any key is missing or None.
|
||||
|
||||
Returns:
|
||||
The value at the end of the key chain, or default if unreachable.
|
||||
"""
|
||||
current = data
|
||||
for key in keys:
|
||||
if not isinstance(current, dict):
|
||||
return default
|
||||
current = current.get(key)
|
||||
return current if current is not None else default
|
||||
|
||||
|
||||
def format_bytes(bytes_value: int | None) -> str:
|
||||
"""Format byte values into human-readable sizes.
|
||||
|
||||
Args:
|
||||
bytes_value: Number of bytes, or None.
|
||||
|
||||
Returns:
|
||||
Human-readable string like "1.00 GB" or "N/A" if input is None/invalid.
|
||||
"""
|
||||
if bytes_value is None:
|
||||
return "N/A"
|
||||
try:
|
||||
value = float(int(bytes_value))
|
||||
except (ValueError, TypeError):
|
||||
return "N/A"
|
||||
for unit in ["B", "KB", "MB", "GB", "TB", "PB"]:
|
||||
if value < 1024.0:
|
||||
return f"{value:.2f} {unit}"
|
||||
value /= 1024.0
|
||||
return f"{value:.2f} EB"
|
||||
|
||||
|
||||
def format_kb(k: Any) -> str:
|
||||
"""Format kilobyte values into human-readable sizes.
|
||||
|
||||
Args:
|
||||
k: Number of kilobytes, or None.
|
||||
|
||||
Returns:
|
||||
Human-readable string like "1.00 GB" or "N/A" if input is None/invalid.
|
||||
"""
|
||||
if k is None:
|
||||
return "N/A"
|
||||
try:
|
||||
k = int(k)
|
||||
except (ValueError, TypeError):
|
||||
return "N/A"
|
||||
if k >= 1024 * 1024 * 1024:
|
||||
return f"{k / (1024 * 1024 * 1024):.2f} TB"
|
||||
if k >= 1024 * 1024:
|
||||
return f"{k / (1024 * 1024):.2f} GB"
|
||||
if k >= 1024:
|
||||
return f"{k / 1024:.2f} MB"
|
||||
return f"{k} KB"
|
||||
@@ -15,8 +15,11 @@ from .config.settings import (
|
||||
UNRAID_MCP_HOST,
|
||||
UNRAID_MCP_PORT,
|
||||
UNRAID_MCP_TRANSPORT,
|
||||
UNRAID_VERIFY_SSL,
|
||||
VERSION,
|
||||
validate_required_config,
|
||||
)
|
||||
from .subscriptions.diagnostics import register_diagnostic_tools
|
||||
from .subscriptions.resources import register_subscription_resources
|
||||
from .tools.array import register_array_tool
|
||||
from .tools.docker import register_docker_tool
|
||||
@@ -44,9 +47,10 @@ mcp = FastMCP(
|
||||
def register_all_modules() -> None:
|
||||
"""Register all tools and resources with the MCP instance."""
|
||||
try:
|
||||
# Register subscription resources first
|
||||
# Register subscription resources and diagnostic tools
|
||||
register_subscription_resources(mcp)
|
||||
logger.info("Subscription resources registered")
|
||||
register_diagnostic_tools(mcp)
|
||||
logger.info("Subscription resources and diagnostic tools registered")
|
||||
|
||||
# Register all consolidated tools
|
||||
registrars = [
|
||||
@@ -73,6 +77,15 @@ def register_all_modules() -> None:
|
||||
|
||||
def run_server() -> None:
|
||||
"""Run the MCP server with the configured transport."""
|
||||
# Validate required configuration before anything else
|
||||
is_valid, missing = validate_required_config()
|
||||
if not is_valid:
|
||||
logger.critical(
|
||||
f"Missing required configuration: {', '.join(missing)}. "
|
||||
"Set these environment variables or add them to your .env file."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Log configuration
|
||||
if UNRAID_API_URL:
|
||||
logger.info(f"UNRAID_API_URL loaded: {UNRAID_API_URL[:20]}...")
|
||||
@@ -88,6 +101,13 @@ def run_server() -> None:
|
||||
logger.info(f"UNRAID_MCP_HOST set to: {UNRAID_MCP_HOST}")
|
||||
logger.info(f"UNRAID_MCP_TRANSPORT set to: {UNRAID_MCP_TRANSPORT}")
|
||||
|
||||
if UNRAID_VERIFY_SSL is False:
|
||||
logger.warning(
|
||||
"SSL VERIFICATION DISABLED (UNRAID_VERIFY_SSL=false). "
|
||||
"Connections to Unraid API are vulnerable to man-in-the-middle attacks. "
|
||||
"Only use this in trusted networks or for development."
|
||||
)
|
||||
|
||||
# Register all modules
|
||||
register_all_modules()
|
||||
|
||||
|
||||
@@ -6,8 +6,10 @@ development and debugging purposes.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
from datetime import datetime
|
||||
import re
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
import websockets
|
||||
@@ -19,7 +21,58 @@ from ..config.settings import UNRAID_API_KEY, UNRAID_API_URL
|
||||
from ..core.exceptions import ToolError
|
||||
from .manager import subscription_manager
|
||||
from .resources import ensure_subscriptions_started
|
||||
from .utils import build_ws_ssl_context
|
||||
from .utils import build_ws_ssl_context, build_ws_url
|
||||
|
||||
|
||||
_ALLOWED_SUBSCRIPTION_NAMES = frozenset(
|
||||
{
|
||||
"logFileSubscription",
|
||||
"containerStatsSubscription",
|
||||
"cpuSubscription",
|
||||
"memorySubscription",
|
||||
"arraySubscription",
|
||||
"networkSubscription",
|
||||
"dockerSubscription",
|
||||
"vmSubscription",
|
||||
}
|
||||
)
|
||||
|
||||
# Pattern: must start with "subscription", contain only a known subscription name,
|
||||
# and not contain mutation/query keywords or semicolons (prevents injection).
|
||||
_SUBSCRIPTION_NAME_PATTERN = re.compile(r"^\s*subscription\b[^{]*\{\s*(\w+)", re.IGNORECASE)
|
||||
_FORBIDDEN_KEYWORDS = re.compile(r"\b(mutation|query)\b", re.IGNORECASE)
|
||||
|
||||
|
||||
def _validate_subscription_query(query: str) -> str:
|
||||
"""Validate that a subscription query is safe to execute.
|
||||
|
||||
Only allows subscription operations targeting whitelisted subscription names.
|
||||
Rejects any query containing mutation/query keywords.
|
||||
|
||||
Returns:
|
||||
The extracted subscription name.
|
||||
|
||||
Raises:
|
||||
ToolError: If the query fails validation.
|
||||
"""
|
||||
if _FORBIDDEN_KEYWORDS.search(query):
|
||||
raise ToolError("Query rejected: must be a subscription, not a mutation or query.")
|
||||
|
||||
match = _SUBSCRIPTION_NAME_PATTERN.match(query)
|
||||
if not match:
|
||||
raise ToolError(
|
||||
"Query rejected: must start with 'subscription' and contain a valid "
|
||||
"subscription operation. Example: subscription { logFileSubscription { ... } }"
|
||||
)
|
||||
|
||||
sub_name = match.group(1)
|
||||
if sub_name not in _ALLOWED_SUBSCRIPTION_NAMES:
|
||||
raise ToolError(
|
||||
f"Subscription '{sub_name}' is not allowed. "
|
||||
f"Allowed subscriptions: {sorted(_ALLOWED_SUBSCRIPTION_NAMES)}"
|
||||
)
|
||||
|
||||
return sub_name
|
||||
|
||||
|
||||
def register_diagnostic_tools(mcp: FastMCP) -> None:
|
||||
@@ -34,6 +87,10 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
||||
"""Test a GraphQL subscription query directly to debug schema issues.
|
||||
|
||||
Use this to find working subscription field names and structure.
|
||||
Only whitelisted subscriptions are allowed (logFileSubscription,
|
||||
containerStatsSubscription, cpuSubscription, memorySubscription,
|
||||
arraySubscription, networkSubscription, dockerSubscription,
|
||||
vmSubscription).
|
||||
|
||||
Args:
|
||||
subscription_query: The GraphQL subscription query to test
|
||||
@@ -41,16 +98,16 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
||||
Returns:
|
||||
Dict containing test results and response data
|
||||
"""
|
||||
try:
|
||||
logger.info(f"[TEST_SUBSCRIPTION] Testing query: {subscription_query}")
|
||||
# Validate before any network I/O
|
||||
sub_name = _validate_subscription_query(subscription_query)
|
||||
|
||||
# Build WebSocket URL
|
||||
if not UNRAID_API_URL:
|
||||
raise ToolError("UNRAID_API_URL is not configured")
|
||||
ws_url = (
|
||||
UNRAID_API_URL.replace("https://", "wss://").replace("http://", "ws://")
|
||||
+ "/graphql"
|
||||
)
|
||||
try:
|
||||
logger.info(f"[TEST_SUBSCRIPTION] Testing validated subscription '{sub_name}'")
|
||||
|
||||
try:
|
||||
ws_url = build_ws_url()
|
||||
except ValueError as e:
|
||||
raise ToolError(str(e)) from e
|
||||
|
||||
ssl_context = build_ws_ssl_context(ws_url)
|
||||
|
||||
@@ -59,6 +116,7 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
||||
ws_url,
|
||||
subprotocols=[Subprotocol("graphql-transport-ws"), Subprotocol("graphql-ws")],
|
||||
ssl=ssl_context,
|
||||
open_timeout=10,
|
||||
ping_interval=30,
|
||||
ping_timeout=10,
|
||||
) as websocket:
|
||||
@@ -122,14 +180,14 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
||||
logger.info("[DIAGNOSTIC] Running subscription diagnostics...")
|
||||
|
||||
# Get comprehensive status
|
||||
status = subscription_manager.get_subscription_status()
|
||||
status = await subscription_manager.get_subscription_status()
|
||||
|
||||
# Initialize connection issues list with proper type
|
||||
connection_issues: list[dict[str, Any]] = []
|
||||
|
||||
# Add environment info with explicit typing
|
||||
diagnostic_info: dict[str, Any] = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"environment": {
|
||||
"auto_start_enabled": subscription_manager.auto_start_enabled,
|
||||
"max_reconnect_attempts": subscription_manager.max_reconnect_attempts,
|
||||
@@ -152,17 +210,9 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
||||
},
|
||||
}
|
||||
|
||||
# Calculate WebSocket URL
|
||||
if UNRAID_API_URL:
|
||||
if UNRAID_API_URL.startswith("https://"):
|
||||
ws_url = "wss://" + UNRAID_API_URL[len("https://") :]
|
||||
elif UNRAID_API_URL.startswith("http://"):
|
||||
ws_url = "ws://" + UNRAID_API_URL[len("http://") :]
|
||||
else:
|
||||
ws_url = UNRAID_API_URL
|
||||
if not ws_url.endswith("/graphql"):
|
||||
ws_url = ws_url.rstrip("/") + "/graphql"
|
||||
diagnostic_info["environment"]["websocket_url"] = ws_url
|
||||
# Calculate WebSocket URL (stays None if UNRAID_API_URL not configured)
|
||||
with contextlib.suppress(ValueError):
|
||||
diagnostic_info["environment"]["websocket_url"] = build_ws_url()
|
||||
|
||||
# Analyze issues
|
||||
for sub_name, sub_status in status.items():
|
||||
|
||||
@@ -8,16 +8,50 @@ error handling, reconnection logic, and authentication.
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
import websockets
|
||||
from websockets.typing import Subprotocol
|
||||
|
||||
from ..config.logging import logger
|
||||
from ..config.settings import UNRAID_API_KEY, UNRAID_API_URL
|
||||
from ..config.settings import UNRAID_API_KEY
|
||||
from ..core.client import _redact_sensitive
|
||||
from ..core.types import SubscriptionData
|
||||
from .utils import build_ws_ssl_context
|
||||
from .utils import build_ws_ssl_context, build_ws_url
|
||||
|
||||
|
||||
# Resource data size limits to prevent unbounded memory growth
|
||||
_MAX_RESOURCE_DATA_BYTES = 1_048_576 # 1MB
|
||||
_MAX_RESOURCE_DATA_LINES = 5_000
|
||||
# Minimum stable connection duration (seconds) before resetting reconnect counter
|
||||
_STABLE_CONNECTION_SECONDS = 30
|
||||
|
||||
|
||||
def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Cap log content in subscription data to prevent unbounded memory growth.
|
||||
|
||||
If the data contains a 'content' field (from log subscriptions) that exceeds
|
||||
size limits, truncate to the most recent _MAX_RESOURCE_DATA_LINES lines.
|
||||
"""
|
||||
for key, value in data.items():
|
||||
if isinstance(value, dict):
|
||||
data[key] = _cap_log_content(value)
|
||||
elif (
|
||||
key == "content"
|
||||
and isinstance(value, str)
|
||||
and len(value.encode("utf-8", errors="replace")) > _MAX_RESOURCE_DATA_BYTES
|
||||
):
|
||||
lines = value.splitlines()
|
||||
if len(lines) > _MAX_RESOURCE_DATA_LINES:
|
||||
truncated = "\n".join(lines[-_MAX_RESOURCE_DATA_LINES:])
|
||||
logger.warning(
|
||||
f"[RESOURCE] Capped log content from {len(lines)} to "
|
||||
f"{_MAX_RESOURCE_DATA_LINES} lines ({len(value)} -> {len(truncated)} chars)"
|
||||
)
|
||||
data[key] = truncated
|
||||
return data
|
||||
|
||||
|
||||
class SubscriptionManager:
|
||||
@@ -26,7 +60,6 @@ class SubscriptionManager:
|
||||
def __init__(self) -> None:
|
||||
self.active_subscriptions: dict[str, asyncio.Task[None]] = {}
|
||||
self.resource_data: dict[str, SubscriptionData] = {}
|
||||
self.websocket: websockets.WebSocketServerProtocol | None = None
|
||||
self.subscription_lock = asyncio.Lock()
|
||||
|
||||
# Configuration
|
||||
@@ -37,6 +70,7 @@ class SubscriptionManager:
|
||||
self.max_reconnect_attempts = int(os.getenv("UNRAID_MAX_RECONNECT_ATTEMPTS", "10"))
|
||||
self.connection_states: dict[str, str] = {} # Track connection state per subscription
|
||||
self.last_error: dict[str, str] = {} # Track last error per subscription
|
||||
self._connection_start_times: dict[str, float] = {} # Track when connections started
|
||||
|
||||
# Define subscription configurations
|
||||
self.subscription_configs = {
|
||||
@@ -165,20 +199,7 @@ class SubscriptionManager:
|
||||
break
|
||||
|
||||
try:
|
||||
# Build WebSocket URL with detailed logging
|
||||
if not UNRAID_API_URL:
|
||||
raise ValueError("UNRAID_API_URL is not configured")
|
||||
|
||||
if UNRAID_API_URL.startswith("https://"):
|
||||
ws_url = "wss://" + UNRAID_API_URL[len("https://") :]
|
||||
elif UNRAID_API_URL.startswith("http://"):
|
||||
ws_url = "ws://" + UNRAID_API_URL[len("http://") :]
|
||||
else:
|
||||
ws_url = UNRAID_API_URL
|
||||
|
||||
if not ws_url.endswith("/graphql"):
|
||||
ws_url = ws_url.rstrip("/") + "/graphql"
|
||||
|
||||
ws_url = build_ws_url()
|
||||
logger.debug(f"[WEBSOCKET:{subscription_name}] Connecting to: {ws_url}")
|
||||
logger.debug(
|
||||
f"[WEBSOCKET:{subscription_name}] API Key present: {'Yes' if UNRAID_API_KEY else 'No'}"
|
||||
@@ -195,6 +216,7 @@ class SubscriptionManager:
|
||||
async with websockets.connect(
|
||||
ws_url,
|
||||
subprotocols=[Subprotocol("graphql-transport-ws"), Subprotocol("graphql-ws")],
|
||||
open_timeout=connect_timeout,
|
||||
ping_interval=20,
|
||||
ping_timeout=10,
|
||||
close_timeout=10,
|
||||
@@ -206,9 +228,9 @@ class SubscriptionManager:
|
||||
)
|
||||
self.connection_states[subscription_name] = "connected"
|
||||
|
||||
# Reset retry count on successful connection
|
||||
self.reconnect_attempts[subscription_name] = 0
|
||||
retry_delay = 5 # Reset delay
|
||||
# Track connection start time — only reset retry counter
|
||||
# after the connection proves stable (>30s connected)
|
||||
self._connection_start_times[subscription_name] = time.monotonic()
|
||||
|
||||
# Initialize GraphQL-WS protocol
|
||||
logger.debug(
|
||||
@@ -290,7 +312,9 @@ class SubscriptionManager:
|
||||
f"[SUBSCRIPTION:{subscription_name}] Subscription message type: {start_type}"
|
||||
)
|
||||
logger.debug(f"[SUBSCRIPTION:{subscription_name}] Query: {query[:100]}...")
|
||||
logger.debug(f"[SUBSCRIPTION:{subscription_name}] Variables: {variables}")
|
||||
logger.debug(
|
||||
f"[SUBSCRIPTION:{subscription_name}] Variables: {_redact_sensitive(variables)}"
|
||||
)
|
||||
|
||||
await websocket.send(json.dumps(subscription_message))
|
||||
logger.info(
|
||||
@@ -326,9 +350,14 @@ class SubscriptionManager:
|
||||
logger.info(
|
||||
f"[DATA:{subscription_name}] Received subscription data update"
|
||||
)
|
||||
capped_data = (
|
||||
_cap_log_content(payload["data"])
|
||||
if isinstance(payload["data"], dict)
|
||||
else payload["data"]
|
||||
)
|
||||
self.resource_data[subscription_name] = SubscriptionData(
|
||||
data=payload["data"],
|
||||
last_updated=datetime.now(),
|
||||
data=capped_data,
|
||||
last_updated=datetime.now(UTC),
|
||||
subscription_type=subscription_name,
|
||||
)
|
||||
logger.debug(
|
||||
@@ -427,6 +456,26 @@ class SubscriptionManager:
|
||||
self.last_error[subscription_name] = error_msg
|
||||
self.connection_states[subscription_name] = "error"
|
||||
|
||||
# Check if connection was stable before deciding on retry behavior
|
||||
start_time = self._connection_start_times.get(subscription_name)
|
||||
if start_time is not None:
|
||||
connected_duration = time.monotonic() - start_time
|
||||
if connected_duration >= _STABLE_CONNECTION_SECONDS:
|
||||
# Connection was stable — reset retry counter and backoff
|
||||
logger.info(
|
||||
f"[WEBSOCKET:{subscription_name}] Connection was stable "
|
||||
f"({connected_duration:.0f}s >= {_STABLE_CONNECTION_SECONDS}s), "
|
||||
f"resetting retry counter"
|
||||
)
|
||||
self.reconnect_attempts[subscription_name] = 0
|
||||
retry_delay = 5
|
||||
else:
|
||||
logger.warning(
|
||||
f"[WEBSOCKET:{subscription_name}] Connection was unstable "
|
||||
f"({connected_duration:.0f}s < {_STABLE_CONNECTION_SECONDS}s), "
|
||||
f"keeping retry counter at {self.reconnect_attempts.get(subscription_name, 0)}"
|
||||
)
|
||||
|
||||
# Calculate backoff delay
|
||||
retry_delay = min(retry_delay * 1.5, max_retry_delay)
|
||||
logger.info(
|
||||
@@ -435,13 +484,14 @@ class SubscriptionManager:
|
||||
self.connection_states[subscription_name] = "reconnecting"
|
||||
await asyncio.sleep(retry_delay)
|
||||
|
||||
def get_resource_data(self, resource_name: str) -> dict[str, Any] | None:
|
||||
async def get_resource_data(self, resource_name: str) -> dict[str, Any] | None:
|
||||
"""Get current resource data with enhanced logging."""
|
||||
logger.debug(f"[RESOURCE:{resource_name}] Resource data requested")
|
||||
|
||||
async with self.subscription_lock:
|
||||
if resource_name in self.resource_data:
|
||||
data = self.resource_data[resource_name]
|
||||
age_seconds = (datetime.now() - data.last_updated).total_seconds()
|
||||
age_seconds = (datetime.now(UTC) - data.last_updated).total_seconds()
|
||||
logger.debug(f"[RESOURCE:{resource_name}] Data found, age: {age_seconds:.1f}s")
|
||||
return data.data
|
||||
logger.debug(f"[RESOURCE:{resource_name}] No data available")
|
||||
@@ -453,10 +503,11 @@ class SubscriptionManager:
|
||||
logger.debug(f"[SUBSCRIPTION_MANAGER] Active subscriptions: {active}")
|
||||
return active
|
||||
|
||||
def get_subscription_status(self) -> dict[str, dict[str, Any]]:
|
||||
async def get_subscription_status(self) -> dict[str, dict[str, Any]]:
|
||||
"""Get detailed status of all subscriptions for diagnostics."""
|
||||
status = {}
|
||||
|
||||
async with self.subscription_lock:
|
||||
for sub_name, config in self.subscription_configs.items():
|
||||
sub_status = {
|
||||
"config": {
|
||||
@@ -475,7 +526,7 @@ class SubscriptionManager:
|
||||
# Add data info if available
|
||||
if sub_name in self.resource_data:
|
||||
data_info = self.resource_data[sub_name]
|
||||
age_seconds = (datetime.now() - data_info.last_updated).total_seconds()
|
||||
age_seconds = (datetime.now(UTC) - data_info.last_updated).total_seconds()
|
||||
sub_status["data"] = {
|
||||
"available": True,
|
||||
"last_updated": data_info.last_updated.isoformat(),
|
||||
|
||||
@@ -82,7 +82,7 @@ def register_subscription_resources(mcp: FastMCP) -> None:
|
||||
async def logs_stream_resource() -> str:
|
||||
"""Real-time log stream data from subscription."""
|
||||
await ensure_subscriptions_started()
|
||||
data = subscription_manager.get_resource_data("logFileSubscription")
|
||||
data = await subscription_manager.get_resource_data("logFileSubscription")
|
||||
if data:
|
||||
return json.dumps(data, indent=2)
|
||||
return json.dumps(
|
||||
|
||||
@@ -2,7 +2,34 @@
|
||||
|
||||
import ssl as _ssl
|
||||
|
||||
from ..config.settings import UNRAID_VERIFY_SSL
|
||||
from ..config.settings import UNRAID_API_URL, UNRAID_VERIFY_SSL
|
||||
|
||||
|
||||
def build_ws_url() -> str:
|
||||
"""Build a WebSocket URL from the configured UNRAID_API_URL.
|
||||
|
||||
Converts http(s) scheme to ws(s) and ensures /graphql path suffix.
|
||||
|
||||
Returns:
|
||||
The WebSocket URL string (e.g. "wss://10.1.0.2:31337/graphql").
|
||||
|
||||
Raises:
|
||||
ValueError: If UNRAID_API_URL is not configured.
|
||||
"""
|
||||
if not UNRAID_API_URL:
|
||||
raise ValueError("UNRAID_API_URL is not configured")
|
||||
|
||||
if UNRAID_API_URL.startswith("https://"):
|
||||
ws_url = "wss://" + UNRAID_API_URL[len("https://") :]
|
||||
elif UNRAID_API_URL.startswith("http://"):
|
||||
ws_url = "ws://" + UNRAID_API_URL[len("http://") :]
|
||||
else:
|
||||
ws_url = UNRAID_API_URL
|
||||
|
||||
if not ws_url.endswith("/graphql"):
|
||||
ws_url = ws_url.rstrip("/") + "/graphql"
|
||||
|
||||
return ws_url
|
||||
|
||||
|
||||
def build_ws_ssl_context(ws_url: str) -> _ssl.SSLContext | None:
|
||||
|
||||
@@ -9,7 +9,7 @@ from fastmcp import FastMCP
|
||||
|
||||
from ..config.logging import logger
|
||||
from ..core.client import make_graphql_request
|
||||
from ..core.exceptions import ToolError
|
||||
from ..core.exceptions import ToolError, tool_error_handler
|
||||
|
||||
|
||||
QUERIES: dict[str, str] = {
|
||||
@@ -74,7 +74,7 @@ def register_array_tool(mcp: FastMCP) -> None:
|
||||
if action not in ALL_ACTIONS:
|
||||
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(ALL_ACTIONS)}")
|
||||
|
||||
try:
|
||||
with tool_error_handler("array", action, logger):
|
||||
logger.info(f"Executing unraid_array action={action}")
|
||||
|
||||
if action in QUERIES:
|
||||
@@ -95,10 +95,4 @@ def register_array_tool(mcp: FastMCP) -> None:
|
||||
"data": data,
|
||||
}
|
||||
|
||||
except ToolError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in unraid_array action={action}: {e}", exc_info=True)
|
||||
raise ToolError(f"Failed to execute array/{action}: {e!s}") from e
|
||||
|
||||
logger.info("Array tool registered successfully")
|
||||
|
||||
@@ -11,7 +11,8 @@ from fastmcp import FastMCP
|
||||
|
||||
from ..config.logging import logger
|
||||
from ..core.client import make_graphql_request
|
||||
from ..core.exceptions import ToolError
|
||||
from ..core.exceptions import ToolError, tool_error_handler
|
||||
from ..core.utils import safe_get
|
||||
|
||||
|
||||
QUERIES: dict[str, str] = {
|
||||
@@ -99,6 +100,10 @@ MUTATIONS: dict[str, str] = {
|
||||
}
|
||||
|
||||
DESTRUCTIVE_ACTIONS = {"remove"}
|
||||
_MUTATION_ACTIONS = {"start", "stop", "restart", "pause", "unpause", "remove", "update"}
|
||||
# NOTE (Code-M-07): "details" and "logs" are listed here because they require a
|
||||
# container_id parameter, but unlike mutations they use fuzzy name matching (not
|
||||
# strict). This is intentional: read-only queries are safe with fuzzy matching.
|
||||
_ACTIONS_REQUIRING_CONTAINER_ID = {
|
||||
"start",
|
||||
"stop",
|
||||
@@ -111,6 +116,7 @@ _ACTIONS_REQUIRING_CONTAINER_ID = {
|
||||
"logs",
|
||||
}
|
||||
ALL_ACTIONS = set(QUERIES) | set(MUTATIONS) | {"restart"}
|
||||
_MAX_TAIL_LINES = 10_000
|
||||
|
||||
DOCKER_ACTIONS = Literal[
|
||||
"list",
|
||||
@@ -130,33 +136,28 @@ DOCKER_ACTIONS = Literal[
|
||||
"check_updates",
|
||||
]
|
||||
|
||||
# Docker container IDs: 64 hex chars + optional suffix (e.g., ":local")
|
||||
# Full PrefixedID: 64 hex chars + optional suffix (e.g., ":local")
|
||||
_DOCKER_ID_PATTERN = re.compile(r"^[a-f0-9]{64}(:[a-z0-9]+)?$", re.IGNORECASE)
|
||||
|
||||
|
||||
def _safe_get(data: dict[str, Any], *keys: str, default: Any = None) -> Any:
|
||||
"""Safely traverse nested dict keys, handling None intermediates."""
|
||||
current = data
|
||||
for key in keys:
|
||||
if not isinstance(current, dict):
|
||||
return default
|
||||
current = current.get(key)
|
||||
return current if current is not None else default
|
||||
# Short hex prefix: at least 12 hex chars (standard Docker short ID length)
|
||||
_DOCKER_SHORT_ID_PATTERN = re.compile(r"^[a-f0-9]{12,63}$", re.IGNORECASE)
|
||||
|
||||
|
||||
def find_container_by_identifier(
|
||||
identifier: str, containers: list[dict[str, Any]]
|
||||
identifier: str, containers: list[dict[str, Any]], *, strict: bool = False
|
||||
) -> dict[str, Any] | None:
|
||||
"""Find a container by ID or name with fuzzy matching.
|
||||
"""Find a container by ID or name with optional fuzzy matching.
|
||||
|
||||
Match priority:
|
||||
1. Exact ID match
|
||||
2. Exact name match (case-sensitive)
|
||||
|
||||
When strict=False (default), also tries:
|
||||
3. Name starts with identifier (case-insensitive)
|
||||
4. Name contains identifier as substring (case-insensitive)
|
||||
|
||||
Note: Short identifiers (e.g. "db") may match unintended containers
|
||||
via substring. Use more specific names or IDs for precision.
|
||||
When strict=True, only exact matches (1 & 2) are used.
|
||||
Use strict=True for mutations to prevent targeting the wrong container.
|
||||
"""
|
||||
if not containers:
|
||||
return None
|
||||
@@ -168,20 +169,24 @@ def find_container_by_identifier(
|
||||
if identifier in c.get("names", []):
|
||||
return c
|
||||
|
||||
# Strict mode: no fuzzy matching allowed
|
||||
if strict:
|
||||
return None
|
||||
|
||||
id_lower = identifier.lower()
|
||||
|
||||
# Priority 3: prefix match (more precise than substring)
|
||||
for c in containers:
|
||||
for name in c.get("names", []):
|
||||
if name.lower().startswith(id_lower):
|
||||
logger.info(f"Prefix match: '{identifier}' -> '{name}'")
|
||||
logger.debug(f"Prefix match: '{identifier}' -> '{name}'")
|
||||
return c
|
||||
|
||||
# Priority 4: substring match (least precise)
|
||||
for c in containers:
|
||||
for name in c.get("names", []):
|
||||
if id_lower in name.lower():
|
||||
logger.info(f"Substring match: '{identifier}' -> '{name}'")
|
||||
logger.debug(f"Substring match: '{identifier}' -> '{name}'")
|
||||
return c
|
||||
|
||||
return None
|
||||
@@ -195,26 +200,61 @@ def get_available_container_names(containers: list[dict[str, Any]]) -> list[str]
|
||||
return names
|
||||
|
||||
|
||||
async def _resolve_container_id(container_id: str) -> str:
|
||||
"""Resolve a container name/identifier to its actual PrefixedID."""
|
||||
def _looks_like_container_id(identifier: str) -> bool:
|
||||
"""Check if an identifier looks like a container ID (full or short hex prefix)."""
|
||||
return bool(_DOCKER_ID_PATTERN.match(identifier) or _DOCKER_SHORT_ID_PATTERN.match(identifier))
|
||||
|
||||
|
||||
async def _resolve_container_id(container_id: str, *, strict: bool = False) -> str:
|
||||
"""Resolve a container name/identifier to its actual PrefixedID.
|
||||
|
||||
Optimization: if the identifier is a full 64-char hex ID (with optional
|
||||
:suffix), skip the container list fetch entirely and use it directly.
|
||||
If it's a short hex prefix (12-63 chars), fetch the list and match by
|
||||
ID prefix. Only fetch the container list for name-based lookups.
|
||||
|
||||
Args:
|
||||
container_id: Container name or ID to resolve
|
||||
strict: When True, only exact name/ID matches are allowed (no fuzzy).
|
||||
Use for mutations to prevent targeting the wrong container.
|
||||
"""
|
||||
# Full PrefixedID: skip the list fetch entirely
|
||||
if _DOCKER_ID_PATTERN.match(container_id):
|
||||
return container_id
|
||||
|
||||
logger.info(f"Resolving container identifier '{container_id}'")
|
||||
logger.info(f"Resolving container identifier '{container_id}' (strict={strict})")
|
||||
list_query = """
|
||||
query ResolveContainerID {
|
||||
docker { containers(skipCache: true) { id names } }
|
||||
}
|
||||
"""
|
||||
data = await make_graphql_request(list_query)
|
||||
containers = _safe_get(data, "docker", "containers", default=[])
|
||||
resolved = find_container_by_identifier(container_id, containers)
|
||||
containers = safe_get(data, "docker", "containers", default=[])
|
||||
|
||||
# Short hex prefix: match by ID prefix before trying name matching
|
||||
if _DOCKER_SHORT_ID_PATTERN.match(container_id):
|
||||
id_lower = container_id.lower()
|
||||
for c in containers:
|
||||
cid = (c.get("id") or "").lower()
|
||||
if cid.startswith(id_lower) or cid.split(":")[0].startswith(id_lower):
|
||||
actual_id = str(c.get("id", ""))
|
||||
logger.info(f"Resolved short ID '{container_id}' -> '{actual_id}'")
|
||||
return actual_id
|
||||
|
||||
resolved = find_container_by_identifier(container_id, containers, strict=strict)
|
||||
if resolved:
|
||||
actual_id = str(resolved.get("id", ""))
|
||||
logger.info(f"Resolved '{container_id}' -> '{actual_id}'")
|
||||
return actual_id
|
||||
|
||||
available = get_available_container_names(containers)
|
||||
if strict:
|
||||
msg = (
|
||||
f"Container '{container_id}' not found by exact match. "
|
||||
f"Mutations require an exact container name or full ID — "
|
||||
f"fuzzy/substring matching is not allowed for safety."
|
||||
)
|
||||
else:
|
||||
msg = f"Container '{container_id}' not found."
|
||||
if available:
|
||||
msg += f" Available: {', '.join(available[:10])}"
|
||||
@@ -264,38 +304,40 @@ def register_docker_tool(mcp: FastMCP) -> None:
|
||||
if action == "network_details" and not network_id:
|
||||
raise ToolError("network_id is required for 'network_details' action")
|
||||
|
||||
try:
|
||||
if tail_lines < 1 or tail_lines > _MAX_TAIL_LINES:
|
||||
raise ToolError(f"tail_lines must be between 1 and {_MAX_TAIL_LINES}, got {tail_lines}")
|
||||
|
||||
with tool_error_handler("docker", action, logger):
|
||||
logger.info(f"Executing unraid_docker action={action}")
|
||||
|
||||
# --- Read-only queries ---
|
||||
if action == "list":
|
||||
data = await make_graphql_request(QUERIES["list"])
|
||||
containers = _safe_get(data, "docker", "containers", default=[])
|
||||
return {"containers": list(containers) if isinstance(containers, list) else []}
|
||||
containers = safe_get(data, "docker", "containers", default=[])
|
||||
return {"containers": containers}
|
||||
|
||||
if action == "details":
|
||||
# Resolve name -> ID first (skips list fetch if already an ID)
|
||||
actual_id = await _resolve_container_id(container_id or "")
|
||||
data = await make_graphql_request(QUERIES["details"])
|
||||
containers = _safe_get(data, "docker", "containers", default=[])
|
||||
container = find_container_by_identifier(container_id or "", containers)
|
||||
if container:
|
||||
return container
|
||||
available = get_available_container_names(containers)
|
||||
msg = f"Container '{container_id}' not found."
|
||||
if available:
|
||||
msg += f" Available: {', '.join(available[:10])}"
|
||||
raise ToolError(msg)
|
||||
containers = safe_get(data, "docker", "containers", default=[])
|
||||
# Match by resolved ID (exact match, no second list fetch needed)
|
||||
for c in containers:
|
||||
if c.get("id") == actual_id:
|
||||
return c
|
||||
raise ToolError(f"Container '{container_id}' not found in details response.")
|
||||
|
||||
if action == "logs":
|
||||
actual_id = await _resolve_container_id(container_id or "")
|
||||
data = await make_graphql_request(
|
||||
QUERIES["logs"], {"id": actual_id, "tail": tail_lines}
|
||||
)
|
||||
return {"logs": _safe_get(data, "docker", "logs")}
|
||||
return {"logs": safe_get(data, "docker", "logs")}
|
||||
|
||||
if action == "networks":
|
||||
data = await make_graphql_request(QUERIES["networks"])
|
||||
networks = data.get("dockerNetworks", [])
|
||||
return {"networks": list(networks) if isinstance(networks, list) else []}
|
||||
return {"networks": networks}
|
||||
|
||||
if action == "network_details":
|
||||
data = await make_graphql_request(QUERIES["network_details"], {"id": network_id})
|
||||
@@ -303,17 +345,17 @@ def register_docker_tool(mcp: FastMCP) -> None:
|
||||
|
||||
if action == "port_conflicts":
|
||||
data = await make_graphql_request(QUERIES["port_conflicts"])
|
||||
conflicts = _safe_get(data, "docker", "portConflicts", default=[])
|
||||
return {"port_conflicts": list(conflicts) if isinstance(conflicts, list) else []}
|
||||
conflicts = safe_get(data, "docker", "portConflicts", default=[])
|
||||
return {"port_conflicts": conflicts}
|
||||
|
||||
if action == "check_updates":
|
||||
data = await make_graphql_request(QUERIES["check_updates"])
|
||||
statuses = _safe_get(data, "docker", "containerUpdateStatuses", default=[])
|
||||
return {"update_statuses": list(statuses) if isinstance(statuses, list) else []}
|
||||
statuses = safe_get(data, "docker", "containerUpdateStatuses", default=[])
|
||||
return {"update_statuses": statuses}
|
||||
|
||||
# --- Mutations ---
|
||||
# --- Mutations (strict matching: no fuzzy/substring) ---
|
||||
if action == "restart":
|
||||
actual_id = await _resolve_container_id(container_id or "")
|
||||
actual_id = await _resolve_container_id(container_id or "", strict=True)
|
||||
# Stop (idempotent: treat "already stopped" as success)
|
||||
stop_data = await make_graphql_request(
|
||||
MUTATIONS["stop"],
|
||||
@@ -330,7 +372,7 @@ def register_docker_tool(mcp: FastMCP) -> None:
|
||||
if start_data.get("idempotent_success"):
|
||||
result = {}
|
||||
else:
|
||||
result = _safe_get(start_data, "docker", "start", default={})
|
||||
result = safe_get(start_data, "docker", "start", default={})
|
||||
response: dict[str, Any] = {
|
||||
"success": True,
|
||||
"action": "restart",
|
||||
@@ -342,12 +384,12 @@ def register_docker_tool(mcp: FastMCP) -> None:
|
||||
|
||||
if action == "update_all":
|
||||
data = await make_graphql_request(MUTATIONS["update_all"])
|
||||
results = _safe_get(data, "docker", "updateAllContainers", default=[])
|
||||
results = safe_get(data, "docker", "updateAllContainers", default=[])
|
||||
return {"success": True, "action": "update_all", "containers": results}
|
||||
|
||||
# Single-container mutations
|
||||
if action in MUTATIONS:
|
||||
actual_id = await _resolve_container_id(container_id or "")
|
||||
actual_id = await _resolve_container_id(container_id or "", strict=True)
|
||||
op_context: dict[str, str] | None = (
|
||||
{"operation": action} if action in ("start", "stop") else None
|
||||
)
|
||||
@@ -382,10 +424,4 @@ def register_docker_tool(mcp: FastMCP) -> None:
|
||||
|
||||
raise ToolError(f"Unhandled action '{action}' — this is a bug")
|
||||
|
||||
except ToolError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in unraid_docker action={action}: {e}", exc_info=True)
|
||||
raise ToolError(f"Failed to execute docker/{action}: {e!s}") from e
|
||||
|
||||
logger.info("Docker tool registered successfully")
|
||||
|
||||
@@ -7,6 +7,7 @@ connection testing, and subscription diagnostics.
|
||||
import datetime
|
||||
import time
|
||||
from typing import Any, Literal
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastmcp import FastMCP
|
||||
|
||||
@@ -19,9 +20,30 @@ from ..config.settings import (
|
||||
VERSION,
|
||||
)
|
||||
from ..core.client import make_graphql_request
|
||||
from ..core.exceptions import ToolError
|
||||
from ..core.exceptions import ToolError, tool_error_handler
|
||||
|
||||
|
||||
def _safe_display_url(url: str | None) -> str | None:
|
||||
"""Return a redacted URL showing only scheme + host + port.
|
||||
|
||||
Strips path, query parameters, credentials, and fragments to avoid
|
||||
leaking internal network topology or embedded secrets (CWE-200).
|
||||
"""
|
||||
if not url:
|
||||
return None
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
host = parsed.hostname or "unknown"
|
||||
if parsed.port:
|
||||
return f"{parsed.scheme}://{host}:{parsed.port}"
|
||||
return f"{parsed.scheme}://{host}"
|
||||
except Exception:
|
||||
# If parsing fails, show nothing rather than leaking the raw URL
|
||||
return "<unparseable>"
|
||||
|
||||
|
||||
ALL_ACTIONS = {"check", "test_connection", "diagnose"}
|
||||
|
||||
HEALTH_ACTIONS = Literal["check", "test_connection", "diagnose"]
|
||||
|
||||
# Severity ordering: only upgrade, never downgrade
|
||||
@@ -53,12 +75,10 @@ def register_health_tool(mcp: FastMCP) -> None:
|
||||
test_connection - Quick connectivity test (just checks { online })
|
||||
diagnose - Subscription system diagnostics
|
||||
"""
|
||||
if action not in ("check", "test_connection", "diagnose"):
|
||||
raise ToolError(
|
||||
f"Invalid action '{action}'. Must be one of: check, test_connection, diagnose"
|
||||
)
|
||||
if action not in ALL_ACTIONS:
|
||||
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(ALL_ACTIONS)}")
|
||||
|
||||
try:
|
||||
with tool_error_handler("health", action, logger):
|
||||
logger.info(f"Executing unraid_health action={action}")
|
||||
|
||||
if action == "test_connection":
|
||||
@@ -79,12 +99,6 @@ def register_health_tool(mcp: FastMCP) -> None:
|
||||
|
||||
raise ToolError(f"Unhandled action '{action}' — this is a bug")
|
||||
|
||||
except ToolError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in unraid_health action={action}: {e}", exc_info=True)
|
||||
raise ToolError(f"Failed to execute health/{action}: {e!s}") from e
|
||||
|
||||
logger.info("Health tool registered successfully")
|
||||
|
||||
|
||||
@@ -111,7 +125,7 @@ async def _comprehensive_check() -> dict[str, Any]:
|
||||
overview { unread { alert warning total } }
|
||||
}
|
||||
docker {
|
||||
containers(skipCache: true) { id state status }
|
||||
containers { id state status }
|
||||
}
|
||||
}
|
||||
"""
|
||||
@@ -135,7 +149,7 @@ async def _comprehensive_check() -> dict[str, Any]:
|
||||
if info:
|
||||
health_info["unraid_system"] = {
|
||||
"status": "connected",
|
||||
"url": UNRAID_API_URL,
|
||||
"url": _safe_display_url(UNRAID_API_URL),
|
||||
"machine_id": info.get("machineId"),
|
||||
"version": info.get("versions", {}).get("unraid"),
|
||||
"uptime": info.get("os", {}).get("uptime"),
|
||||
@@ -215,6 +229,42 @@ async def _comprehensive_check() -> dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def _analyze_subscription_status(
|
||||
status: dict[str, Any],
|
||||
) -> tuple[int, list[dict[str, Any]]]:
|
||||
"""Analyze subscription status dict, returning error count and connection issues.
|
||||
|
||||
This is the canonical implementation of subscription status analysis.
|
||||
TODO: subscriptions/diagnostics.py (lines 168-182) duplicates this logic.
|
||||
That module should be refactored to call this helper once file ownership
|
||||
allows cross-agent edits. See Code-H05.
|
||||
|
||||
Args:
|
||||
status: Dict of subscription name -> status info from get_subscription_status().
|
||||
|
||||
Returns:
|
||||
Tuple of (error_count, connection_issues_list).
|
||||
"""
|
||||
error_count = 0
|
||||
connection_issues: list[dict[str, Any]] = []
|
||||
|
||||
for sub_name, sub_status in status.items():
|
||||
runtime = sub_status.get("runtime", {})
|
||||
conn_state = runtime.get("connection_state", "unknown")
|
||||
if conn_state in ("error", "auth_failed", "timeout", "max_retries_exceeded"):
|
||||
error_count += 1
|
||||
if runtime.get("last_error"):
|
||||
connection_issues.append(
|
||||
{
|
||||
"subscription": sub_name,
|
||||
"state": conn_state,
|
||||
"error": runtime["last_error"],
|
||||
}
|
||||
)
|
||||
|
||||
return error_count, connection_issues
|
||||
|
||||
|
||||
async def _diagnose_subscriptions() -> dict[str, Any]:
|
||||
"""Import and run subscription diagnostics."""
|
||||
try:
|
||||
@@ -223,13 +273,10 @@ async def _diagnose_subscriptions() -> dict[str, Any]:
|
||||
|
||||
await ensure_subscriptions_started()
|
||||
|
||||
status = subscription_manager.get_subscription_status()
|
||||
# This list is intentionally placed into the summary dict below and then
|
||||
# appended to in the loop — the mutable alias ensures both references
|
||||
# reflect the same data without a second pass.
|
||||
connection_issues: list[dict[str, Any]] = []
|
||||
status = await subscription_manager.get_subscription_status()
|
||||
error_count, connection_issues = _analyze_subscription_status(status)
|
||||
|
||||
diagnostic_info: dict[str, Any] = {
|
||||
return {
|
||||
"timestamp": datetime.datetime.now(datetime.UTC).isoformat(),
|
||||
"environment": {
|
||||
"auto_start_enabled": subscription_manager.auto_start_enabled,
|
||||
@@ -241,27 +288,11 @@ async def _diagnose_subscriptions() -> dict[str, Any]:
|
||||
"total_configured": len(subscription_manager.subscription_configs),
|
||||
"active_count": len(subscription_manager.active_subscriptions),
|
||||
"with_data": len(subscription_manager.resource_data),
|
||||
"in_error_state": 0,
|
||||
"in_error_state": error_count,
|
||||
"connection_issues": connection_issues,
|
||||
},
|
||||
}
|
||||
|
||||
for sub_name, sub_status in status.items():
|
||||
runtime = sub_status.get("runtime", {})
|
||||
conn_state = runtime.get("connection_state", "unknown")
|
||||
if conn_state in ("error", "auth_failed", "timeout", "max_retries_exceeded"):
|
||||
diagnostic_info["summary"]["in_error_state"] += 1
|
||||
if runtime.get("last_error"):
|
||||
connection_issues.append(
|
||||
{
|
||||
"subscription": sub_name,
|
||||
"state": conn_state,
|
||||
"error": runtime["last_error"],
|
||||
}
|
||||
)
|
||||
|
||||
return diagnostic_info
|
||||
|
||||
except ImportError:
|
||||
return {
|
||||
"error": "Subscription modules not available",
|
||||
|
||||
@@ -10,7 +10,8 @@ from fastmcp import FastMCP
|
||||
|
||||
from ..config.logging import logger
|
||||
from ..core.client import make_graphql_request
|
||||
from ..core.exceptions import ToolError
|
||||
from ..core.exceptions import ToolError, tool_error_handler
|
||||
from ..core.utils import format_kb
|
||||
|
||||
|
||||
# Pre-built queries keyed by action name
|
||||
@@ -19,7 +20,7 @@ QUERIES: dict[str, str] = {
|
||||
query GetSystemInfo {
|
||||
info {
|
||||
os { platform distro release codename kernel arch hostname codepage logofile serial build uptime }
|
||||
cpu { manufacturer brand vendor family model stepping revision voltage speed speedmin speedmax threads cores processors socket cache flags }
|
||||
cpu { manufacturer brand vendor family model stepping revision voltage speed speedmin speedmax threads cores processors socket cache }
|
||||
memory {
|
||||
layout { bank type clockSpeed formFactor manufacturer partNum serialNum }
|
||||
}
|
||||
@@ -81,7 +82,6 @@ QUERIES: dict[str, str] = {
|
||||
shareAvahiEnabled safeMode startMode configValid configError joinStatus
|
||||
deviceCount flashGuid flashProduct flashVendor mdState mdVersion
|
||||
shareCount shareSmbCount shareNfsCount shareAfpCount shareMoverActive
|
||||
csrfToken
|
||||
}
|
||||
}
|
||||
""",
|
||||
@@ -156,6 +156,8 @@ QUERIES: dict[str, str] = {
|
||||
""",
|
||||
}
|
||||
|
||||
ALL_ACTIONS = set(QUERIES)
|
||||
|
||||
INFO_ACTIONS = Literal[
|
||||
"overview",
|
||||
"array",
|
||||
@@ -178,8 +180,12 @@ INFO_ACTIONS = Literal[
|
||||
"ups_config",
|
||||
]
|
||||
|
||||
assert set(QUERIES.keys()) == set(INFO_ACTIONS.__args__), (
|
||||
"QUERIES keys and INFO_ACTIONS are out of sync"
|
||||
if set(INFO_ACTIONS.__args__) != ALL_ACTIONS:
|
||||
_missing = ALL_ACTIONS - set(INFO_ACTIONS.__args__)
|
||||
_extra = set(INFO_ACTIONS.__args__) - ALL_ACTIONS
|
||||
raise RuntimeError(
|
||||
f"QUERIES keys and INFO_ACTIONS are out of sync. "
|
||||
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
|
||||
)
|
||||
|
||||
|
||||
@@ -189,17 +195,17 @@ def _process_system_info(raw_info: dict[str, Any]) -> dict[str, Any]:
|
||||
if raw_info.get("os"):
|
||||
os_info = raw_info["os"]
|
||||
summary["os"] = (
|
||||
f"{os_info.get('distro', '')} {os_info.get('release', '')} "
|
||||
f"({os_info.get('platform', '')}, {os_info.get('arch', '')})"
|
||||
f"{os_info.get('distro') or 'unknown'} {os_info.get('release') or 'unknown'} "
|
||||
f"({os_info.get('platform') or 'unknown'}, {os_info.get('arch') or 'unknown'})"
|
||||
)
|
||||
summary["hostname"] = os_info.get("hostname")
|
||||
summary["hostname"] = os_info.get("hostname") or "unknown"
|
||||
summary["uptime"] = os_info.get("uptime")
|
||||
|
||||
if raw_info.get("cpu"):
|
||||
cpu = raw_info["cpu"]
|
||||
summary["cpu"] = (
|
||||
f"{cpu.get('manufacturer', '')} {cpu.get('brand', '')} "
|
||||
f"({cpu.get('cores', '?')} cores, {cpu.get('threads', '?')} threads)"
|
||||
f"{cpu.get('manufacturer') or 'unknown'} {cpu.get('brand') or 'unknown'} "
|
||||
f"({cpu.get('cores') or '?'} cores, {cpu.get('threads') or '?'} threads)"
|
||||
)
|
||||
|
||||
if raw_info.get("memory") and raw_info["memory"].get("layout"):
|
||||
@@ -207,10 +213,10 @@ def _process_system_info(raw_info: dict[str, Any]) -> dict[str, Any]:
|
||||
summary["memory_layout_details"] = []
|
||||
for stick in mem_layout:
|
||||
summary["memory_layout_details"].append(
|
||||
f"Bank {stick.get('bank', '?')}: Type {stick.get('type', '?')}, "
|
||||
f"Speed {stick.get('clockSpeed', '?')}MHz, "
|
||||
f"Manufacturer: {stick.get('manufacturer', '?')}, "
|
||||
f"Part: {stick.get('partNum', '?')}"
|
||||
f"Bank {stick.get('bank') or '?'}: Type {stick.get('type') or '?'}, "
|
||||
f"Speed {stick.get('clockSpeed') or '?'}MHz, "
|
||||
f"Manufacturer: {stick.get('manufacturer') or '?'}, "
|
||||
f"Part: {stick.get('partNum') or '?'}"
|
||||
)
|
||||
summary["memory_summary"] = (
|
||||
"Stick layout details retrieved. Overall total/used/free memory stats "
|
||||
@@ -255,31 +261,14 @@ def _analyze_disk_health(disks: list[dict[str, Any]]) -> dict[str, int]:
|
||||
return counts
|
||||
|
||||
|
||||
def _format_kb(k: Any) -> str:
|
||||
"""Format kilobyte values into human-readable sizes."""
|
||||
if k is None:
|
||||
return "N/A"
|
||||
try:
|
||||
k = int(k)
|
||||
except (ValueError, TypeError):
|
||||
return "N/A"
|
||||
if k >= 1024 * 1024 * 1024:
|
||||
return f"{k / (1024 * 1024 * 1024):.2f} TB"
|
||||
if k >= 1024 * 1024:
|
||||
return f"{k / (1024 * 1024):.2f} GB"
|
||||
if k >= 1024:
|
||||
return f"{k / 1024:.2f} MB"
|
||||
return f"{k} KB"
|
||||
|
||||
|
||||
def _process_array_status(raw: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Process raw array data into summary + details."""
|
||||
summary: dict[str, Any] = {"state": raw.get("state")}
|
||||
if raw.get("capacity") and raw["capacity"].get("kilobytes"):
|
||||
kb = raw["capacity"]["kilobytes"]
|
||||
summary["capacity_total"] = _format_kb(kb.get("total"))
|
||||
summary["capacity_used"] = _format_kb(kb.get("used"))
|
||||
summary["capacity_free"] = _format_kb(kb.get("free"))
|
||||
summary["capacity_total"] = format_kb(kb.get("total"))
|
||||
summary["capacity_used"] = format_kb(kb.get("used"))
|
||||
summary["capacity_free"] = format_kb(kb.get("free"))
|
||||
|
||||
summary["num_parity_disks"] = len(raw.get("parities", []))
|
||||
summary["num_data_disks"] = len(raw.get("disks", []))
|
||||
@@ -345,8 +334,8 @@ def register_info_tool(mcp: FastMCP) -> None:
|
||||
ups_device - Single UPS device (requires device_id)
|
||||
ups_config - UPS configuration
|
||||
"""
|
||||
if action not in QUERIES:
|
||||
raise ToolError(f"Invalid action '{action}'. Must be one of: {list(QUERIES.keys())}")
|
||||
if action not in ALL_ACTIONS:
|
||||
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(ALL_ACTIONS)}")
|
||||
|
||||
if action == "ups_device" and not device_id:
|
||||
raise ToolError("device_id is required for ups_device action")
|
||||
@@ -377,7 +366,7 @@ def register_info_tool(mcp: FastMCP) -> None:
|
||||
"ups_devices": ("upsDevices", "ups_devices"),
|
||||
}
|
||||
|
||||
try:
|
||||
with tool_error_handler("info", action, logger):
|
||||
logger.info(f"Executing unraid_info action={action}")
|
||||
data = await make_graphql_request(query, variables)
|
||||
|
||||
@@ -426,14 +415,8 @@ def register_info_tool(mcp: FastMCP) -> None:
|
||||
if action in list_actions:
|
||||
response_key, output_key = list_actions[action]
|
||||
items = data.get(response_key) or []
|
||||
return {output_key: list(items) if isinstance(items, list) else []}
|
||||
return {output_key: items}
|
||||
|
||||
raise ToolError(f"Unhandled action '{action}' — this is a bug")
|
||||
|
||||
except ToolError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in unraid_info action={action}: {e}", exc_info=True)
|
||||
raise ToolError(f"Failed to execute info/{action}: {e!s}") from e
|
||||
|
||||
logger.info("Info tool registered successfully")
|
||||
|
||||
@@ -10,7 +10,7 @@ from fastmcp import FastMCP
|
||||
|
||||
from ..config.logging import logger
|
||||
from ..core.client import make_graphql_request
|
||||
from ..core.exceptions import ToolError
|
||||
from ..core.exceptions import ToolError, tool_error_handler
|
||||
|
||||
|
||||
QUERIES: dict[str, str] = {
|
||||
@@ -45,6 +45,7 @@ MUTATIONS: dict[str, str] = {
|
||||
}
|
||||
|
||||
DESTRUCTIVE_ACTIONS = {"delete"}
|
||||
ALL_ACTIONS = set(QUERIES) | set(MUTATIONS)
|
||||
|
||||
KEY_ACTIONS = Literal[
|
||||
"list",
|
||||
@@ -76,14 +77,13 @@ def register_keys_tool(mcp: FastMCP) -> None:
|
||||
update - Update an API key (requires key_id; optional name, roles)
|
||||
delete - Delete API keys (requires key_id, confirm=True)
|
||||
"""
|
||||
all_actions = set(QUERIES) | set(MUTATIONS)
|
||||
if action not in all_actions:
|
||||
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(all_actions)}")
|
||||
if action not in ALL_ACTIONS:
|
||||
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(ALL_ACTIONS)}")
|
||||
|
||||
if action in DESTRUCTIVE_ACTIONS and not confirm:
|
||||
raise ToolError(f"Action '{action}' is destructive. Set confirm=True to proceed.")
|
||||
|
||||
try:
|
||||
with tool_error_handler("keys", action, logger):
|
||||
logger.info(f"Executing unraid_keys action={action}")
|
||||
|
||||
if action == "list":
|
||||
@@ -141,10 +141,4 @@ def register_keys_tool(mcp: FastMCP) -> None:
|
||||
|
||||
raise ToolError(f"Unhandled action '{action}' — this is a bug")
|
||||
|
||||
except ToolError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in unraid_keys action={action}: {e}", exc_info=True)
|
||||
raise ToolError(f"Failed to execute keys/{action}: {e!s}") from e
|
||||
|
||||
logger.info("Keys tool registered successfully")
|
||||
|
||||
@@ -10,7 +10,7 @@ from fastmcp import FastMCP
|
||||
|
||||
from ..config.logging import logger
|
||||
from ..core.client import make_graphql_request
|
||||
from ..core.exceptions import ToolError
|
||||
from ..core.exceptions import ToolError, tool_error_handler
|
||||
|
||||
|
||||
QUERIES: dict[str, str] = {
|
||||
@@ -76,6 +76,8 @@ MUTATIONS: dict[str, str] = {
|
||||
}
|
||||
|
||||
DESTRUCTIVE_ACTIONS = {"delete", "delete_archived"}
|
||||
ALL_ACTIONS = set(QUERIES) | set(MUTATIONS)
|
||||
_VALID_IMPORTANCE = {"ALERT", "WARNING", "NORMAL"}
|
||||
|
||||
NOTIFICATION_ACTIONS = Literal[
|
||||
"overview",
|
||||
@@ -120,16 +122,13 @@ def register_notifications_tool(mcp: FastMCP) -> None:
|
||||
delete_archived - Delete all archived notifications (requires confirm=True)
|
||||
archive_all - Archive all notifications (optional importance filter)
|
||||
"""
|
||||
all_actions = {**QUERIES, **MUTATIONS}
|
||||
if action not in all_actions:
|
||||
raise ToolError(
|
||||
f"Invalid action '{action}'. Must be one of: {list(all_actions.keys())}"
|
||||
)
|
||||
if action not in ALL_ACTIONS:
|
||||
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(ALL_ACTIONS)}")
|
||||
|
||||
if action in DESTRUCTIVE_ACTIONS and not confirm:
|
||||
raise ToolError(f"Action '{action}' is destructive. Set confirm=True to proceed.")
|
||||
|
||||
try:
|
||||
with tool_error_handler("notifications", action, logger):
|
||||
logger.info(f"Executing unraid_notifications action={action}")
|
||||
|
||||
if action == "overview":
|
||||
@@ -147,18 +146,29 @@ def register_notifications_tool(mcp: FastMCP) -> None:
|
||||
filter_vars["importance"] = importance.upper()
|
||||
data = await make_graphql_request(QUERIES["list"], {"filter": filter_vars})
|
||||
notifications = data.get("notifications", {})
|
||||
result = notifications.get("list", [])
|
||||
return {"notifications": list(result) if isinstance(result, list) else []}
|
||||
return {"notifications": notifications.get("list", [])}
|
||||
|
||||
if action == "warnings":
|
||||
data = await make_graphql_request(QUERIES["warnings"])
|
||||
notifications = data.get("notifications", {})
|
||||
result = notifications.get("warningsAndAlerts", [])
|
||||
return {"warnings": list(result) if isinstance(result, list) else []}
|
||||
return {"warnings": notifications.get("warningsAndAlerts", [])}
|
||||
|
||||
if action == "create":
|
||||
if title is None or subject is None or description is None or importance is None:
|
||||
raise ToolError("create requires title, subject, description, and importance")
|
||||
if importance.upper() not in _VALID_IMPORTANCE:
|
||||
raise ToolError(
|
||||
f"importance must be one of: {', '.join(sorted(_VALID_IMPORTANCE))}. "
|
||||
f"Got: '{importance}'"
|
||||
)
|
||||
if len(title) > 200:
|
||||
raise ToolError(f"title must be at most 200 characters (got {len(title)})")
|
||||
if len(subject) > 500:
|
||||
raise ToolError(f"subject must be at most 500 characters (got {len(subject)})")
|
||||
if len(description) > 2000:
|
||||
raise ToolError(
|
||||
f"description must be at most 2000 characters (got {len(description)})"
|
||||
)
|
||||
input_data = {
|
||||
"title": title,
|
||||
"subject": subject,
|
||||
@@ -196,10 +206,4 @@ def register_notifications_tool(mcp: FastMCP) -> None:
|
||||
|
||||
raise ToolError(f"Unhandled action '{action}' — this is a bug")
|
||||
|
||||
except ToolError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in unraid_notifications action={action}: {e}", exc_info=True)
|
||||
raise ToolError(f"Failed to execute notifications/{action}: {e!s}") from e
|
||||
|
||||
logger.info("Notifications tool registered successfully")
|
||||
|
||||
@@ -4,13 +4,14 @@ Provides the `unraid_rclone` tool with 4 actions for managing
|
||||
cloud storage remotes (S3, Google Drive, Dropbox, FTP, etc.).
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Any, Literal
|
||||
|
||||
from fastmcp import FastMCP
|
||||
|
||||
from ..config.logging import logger
|
||||
from ..core.client import make_graphql_request
|
||||
from ..core.exceptions import ToolError
|
||||
from ..core.exceptions import ToolError, tool_error_handler
|
||||
|
||||
|
||||
QUERIES: dict[str, str] = {
|
||||
@@ -49,6 +50,51 @@ RCLONE_ACTIONS = Literal[
|
||||
"delete_remote",
|
||||
]
|
||||
|
||||
# Max config entries to prevent abuse
|
||||
_MAX_CONFIG_KEYS = 50
|
||||
# Pattern for suspicious key names (path traversal, shell metacharacters)
|
||||
_DANGEROUS_KEY_PATTERN = re.compile(r"[.]{2}|[/\\;|`$(){}]")
|
||||
# Max length for individual config values
|
||||
_MAX_VALUE_LENGTH = 4096
|
||||
|
||||
|
||||
def _validate_config_data(config_data: dict[str, Any]) -> dict[str, str]:
|
||||
"""Validate and sanitize rclone config_data before passing to GraphQL.
|
||||
|
||||
Ensures all keys and values are safe strings with no injection vectors.
|
||||
|
||||
Raises:
|
||||
ToolError: If config_data contains invalid keys or values
|
||||
"""
|
||||
if len(config_data) > _MAX_CONFIG_KEYS:
|
||||
raise ToolError(f"config_data has {len(config_data)} keys (max {_MAX_CONFIG_KEYS})")
|
||||
|
||||
validated: dict[str, str] = {}
|
||||
for key, value in config_data.items():
|
||||
if not isinstance(key, str) or not key.strip():
|
||||
raise ToolError(
|
||||
f"config_data keys must be non-empty strings, got: {type(key).__name__}"
|
||||
)
|
||||
if _DANGEROUS_KEY_PATTERN.search(key):
|
||||
raise ToolError(
|
||||
f"config_data key '{key}' contains disallowed characters "
|
||||
f"(path traversal or shell metacharacters)"
|
||||
)
|
||||
if not isinstance(value, (str, int, float, bool)):
|
||||
raise ToolError(
|
||||
f"config_data['{key}'] must be a string, number, or boolean, "
|
||||
f"got: {type(value).__name__}"
|
||||
)
|
||||
str_value = str(value)
|
||||
if len(str_value) > _MAX_VALUE_LENGTH:
|
||||
raise ToolError(
|
||||
f"config_data['{key}'] value exceeds max length "
|
||||
f"({len(str_value)} > {_MAX_VALUE_LENGTH})"
|
||||
)
|
||||
validated[key] = str_value
|
||||
|
||||
return validated
|
||||
|
||||
|
||||
def register_rclone_tool(mcp: FastMCP) -> None:
|
||||
"""Register the unraid_rclone tool with the FastMCP instance."""
|
||||
@@ -75,7 +121,7 @@ def register_rclone_tool(mcp: FastMCP) -> None:
|
||||
if action in DESTRUCTIVE_ACTIONS and not confirm:
|
||||
raise ToolError(f"Action '{action}' is destructive. Set confirm=True to proceed.")
|
||||
|
||||
try:
|
||||
with tool_error_handler("rclone", action, logger):
|
||||
logger.info(f"Executing unraid_rclone action={action}")
|
||||
|
||||
if action == "list_remotes":
|
||||
@@ -96,9 +142,10 @@ def register_rclone_tool(mcp: FastMCP) -> None:
|
||||
if action == "create_remote":
|
||||
if name is None or provider_type is None or config_data is None:
|
||||
raise ToolError("create_remote requires name, provider_type, and config_data")
|
||||
validated_config = _validate_config_data(config_data)
|
||||
data = await make_graphql_request(
|
||||
MUTATIONS["create_remote"],
|
||||
{"input": {"name": name, "type": provider_type, "config": config_data}},
|
||||
{"input": {"name": name, "type": provider_type, "config": validated_config}},
|
||||
)
|
||||
remote = data.get("rclone", {}).get("createRCloneRemote")
|
||||
if not remote:
|
||||
@@ -127,10 +174,4 @@ def register_rclone_tool(mcp: FastMCP) -> None:
|
||||
|
||||
raise ToolError(f"Unhandled action '{action}' — this is a bug")
|
||||
|
||||
except ToolError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in unraid_rclone action={action}: {e}", exc_info=True)
|
||||
raise ToolError(f"Failed to execute rclone/{action}: {e!s}") from e
|
||||
|
||||
logger.info("RClone tool registered successfully")
|
||||
|
||||
@@ -4,17 +4,19 @@ Provides the `unraid_storage` tool with 6 actions for shares, physical disks,
|
||||
unassigned devices, log files, and log content retrieval.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any, Literal
|
||||
|
||||
import anyio
|
||||
from fastmcp import FastMCP
|
||||
|
||||
from ..config.logging import logger
|
||||
from ..core.client import DISK_TIMEOUT, make_graphql_request
|
||||
from ..core.exceptions import ToolError
|
||||
from ..core.exceptions import ToolError, tool_error_handler
|
||||
from ..core.utils import format_bytes
|
||||
|
||||
|
||||
_ALLOWED_LOG_PREFIXES = ("/var/log/", "/boot/logs/", "/mnt/")
|
||||
_MAX_TAIL_LINES = 10_000
|
||||
|
||||
QUERIES: dict[str, str] = {
|
||||
"shares": """
|
||||
@@ -56,6 +58,8 @@ QUERIES: dict[str, str] = {
|
||||
""",
|
||||
}
|
||||
|
||||
ALL_ACTIONS = set(QUERIES)
|
||||
|
||||
STORAGE_ACTIONS = Literal[
|
||||
"shares",
|
||||
"disks",
|
||||
@@ -66,21 +70,6 @@ STORAGE_ACTIONS = Literal[
|
||||
]
|
||||
|
||||
|
||||
def format_bytes(bytes_value: int | None) -> str:
|
||||
"""Format byte values into human-readable sizes."""
|
||||
if bytes_value is None:
|
||||
return "N/A"
|
||||
try:
|
||||
value = float(int(bytes_value))
|
||||
except (ValueError, TypeError):
|
||||
return "N/A"
|
||||
for unit in ["B", "KB", "MB", "GB", "TB", "PB"]:
|
||||
if value < 1024.0:
|
||||
return f"{value:.2f} {unit}"
|
||||
value /= 1024.0
|
||||
return f"{value:.2f} EB"
|
||||
|
||||
|
||||
def register_storage_tool(mcp: FastMCP) -> None:
|
||||
"""Register the unraid_storage tool with the FastMCP instance."""
|
||||
|
||||
@@ -101,17 +90,22 @@ def register_storage_tool(mcp: FastMCP) -> None:
|
||||
log_files - List available log files
|
||||
logs - Retrieve log content (requires log_path, optional tail_lines)
|
||||
"""
|
||||
if action not in QUERIES:
|
||||
raise ToolError(f"Invalid action '{action}'. Must be one of: {list(QUERIES.keys())}")
|
||||
if action not in ALL_ACTIONS:
|
||||
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(ALL_ACTIONS)}")
|
||||
|
||||
if action == "disk_details" and not disk_id:
|
||||
raise ToolError("disk_id is required for 'disk_details' action")
|
||||
|
||||
if tail_lines < 1 or tail_lines > _MAX_TAIL_LINES:
|
||||
raise ToolError(f"tail_lines must be between 1 and {_MAX_TAIL_LINES}, got {tail_lines}")
|
||||
|
||||
if action == "logs":
|
||||
if not log_path:
|
||||
raise ToolError("log_path is required for 'logs' action")
|
||||
# Resolve path to prevent traversal attacks (e.g. /var/log/../../etc/shadow)
|
||||
normalized = str(await anyio.Path(log_path).resolve())
|
||||
# Resolve path synchronously to prevent traversal attacks.
|
||||
# Using os.path.realpath instead of anyio.Path.resolve() because the
|
||||
# async variant blocks on NFS-mounted paths under /mnt/ (Perf-AI-1).
|
||||
normalized = os.path.realpath(log_path) # noqa: ASYNC240
|
||||
if not any(normalized.startswith(p) for p in _ALLOWED_LOG_PREFIXES):
|
||||
raise ToolError(
|
||||
f"log_path must start with one of: {', '.join(_ALLOWED_LOG_PREFIXES)}. "
|
||||
@@ -128,17 +122,15 @@ def register_storage_tool(mcp: FastMCP) -> None:
|
||||
elif action == "logs":
|
||||
variables = {"path": log_path, "lines": tail_lines}
|
||||
|
||||
try:
|
||||
with tool_error_handler("storage", action, logger):
|
||||
logger.info(f"Executing unraid_storage action={action}")
|
||||
data = await make_graphql_request(query, variables, custom_timeout=custom_timeout)
|
||||
|
||||
if action == "shares":
|
||||
shares = data.get("shares", [])
|
||||
return {"shares": list(shares) if isinstance(shares, list) else []}
|
||||
return {"shares": data.get("shares", [])}
|
||||
|
||||
if action == "disks":
|
||||
disks = data.get("disks", [])
|
||||
return {"disks": list(disks) if isinstance(disks, list) else []}
|
||||
return {"disks": data.get("disks", [])}
|
||||
|
||||
if action == "disk_details":
|
||||
raw = data.get("disk", {})
|
||||
@@ -159,22 +151,14 @@ def register_storage_tool(mcp: FastMCP) -> None:
|
||||
return {"summary": summary, "details": raw}
|
||||
|
||||
if action == "unassigned":
|
||||
devices = data.get("unassignedDevices", [])
|
||||
return {"devices": list(devices) if isinstance(devices, list) else []}
|
||||
return {"devices": data.get("unassignedDevices", [])}
|
||||
|
||||
if action == "log_files":
|
||||
files = data.get("logFiles", [])
|
||||
return {"log_files": list(files) if isinstance(files, list) else []}
|
||||
return {"log_files": data.get("logFiles", [])}
|
||||
|
||||
if action == "logs":
|
||||
return dict(data.get("logFile") or {})
|
||||
|
||||
raise ToolError(f"Unhandled action '{action}' — this is a bug")
|
||||
|
||||
except ToolError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in unraid_storage action={action}: {e}", exc_info=True)
|
||||
raise ToolError(f"Failed to execute storage/{action}: {e!s}") from e
|
||||
|
||||
logger.info("Storage tool registered successfully")
|
||||
|
||||
@@ -10,7 +10,7 @@ from fastmcp import FastMCP
|
||||
|
||||
from ..config.logging import logger
|
||||
from ..core.client import make_graphql_request
|
||||
from ..core.exceptions import ToolError
|
||||
from ..core.exceptions import ToolError, tool_error_handler
|
||||
|
||||
|
||||
QUERIES: dict[str, str] = {
|
||||
@@ -39,17 +39,11 @@ def register_users_tool(mcp: FastMCP) -> None:
|
||||
Note: Unraid API does not support user management operations (list, add, delete).
|
||||
"""
|
||||
if action not in ALL_ACTIONS:
|
||||
raise ToolError(f"Invalid action '{action}'. Must be: me")
|
||||
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(ALL_ACTIONS)}")
|
||||
|
||||
try:
|
||||
with tool_error_handler("users", action, logger):
|
||||
logger.info("Executing unraid_users action=me")
|
||||
data = await make_graphql_request(QUERIES["me"])
|
||||
return data.get("me") or {}
|
||||
|
||||
except ToolError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in unraid_users action=me: {e}", exc_info=True)
|
||||
raise ToolError(f"Failed to execute users/me: {e!s}") from e
|
||||
|
||||
logger.info("Users tool registered successfully")
|
||||
|
||||
@@ -10,7 +10,7 @@ from fastmcp import FastMCP
|
||||
|
||||
from ..config.logging import logger
|
||||
from ..core.client import make_graphql_request
|
||||
from ..core.exceptions import ToolError
|
||||
from ..core.exceptions import ToolError, tool_error_handler
|
||||
|
||||
|
||||
QUERIES: dict[str, str] = {
|
||||
@@ -19,6 +19,13 @@ QUERIES: dict[str, str] = {
|
||||
vms { id domains { id name state uuid } }
|
||||
}
|
||||
""",
|
||||
# NOTE: The Unraid GraphQL API does not expose a single-VM query.
|
||||
# The details query is identical to list; client-side filtering is required.
|
||||
"details": """
|
||||
query ListVMs {
|
||||
vms { id domains { id name state uuid } }
|
||||
}
|
||||
""",
|
||||
}
|
||||
|
||||
MUTATIONS: dict[str, str] = {
|
||||
@@ -64,7 +71,7 @@ VM_ACTIONS = Literal[
|
||||
"reset",
|
||||
]
|
||||
|
||||
ALL_ACTIONS = set(QUERIES) | set(MUTATIONS) | {"details"}
|
||||
ALL_ACTIONS = set(QUERIES) | set(MUTATIONS)
|
||||
|
||||
|
||||
def register_vm_tool(mcp: FastMCP) -> None:
|
||||
@@ -98,20 +105,26 @@ def register_vm_tool(mcp: FastMCP) -> None:
|
||||
if action in DESTRUCTIVE_ACTIONS and not confirm:
|
||||
raise ToolError(f"Action '{action}' is destructive. Set confirm=True to proceed.")
|
||||
|
||||
with tool_error_handler("vm", action, logger):
|
||||
try:
|
||||
logger.info(f"Executing unraid_vm action={action}")
|
||||
|
||||
if action in ("list", "details"):
|
||||
if action == "list":
|
||||
data = await make_graphql_request(QUERIES["list"])
|
||||
if data.get("vms"):
|
||||
vms = data["vms"].get("domains") or data["vms"].get("domain") or []
|
||||
if isinstance(vms, dict):
|
||||
vms = [vms]
|
||||
|
||||
if action == "list":
|
||||
return {"vms": vms}
|
||||
return {"vms": []}
|
||||
|
||||
# details: find specific VM
|
||||
if action == "details":
|
||||
data = await make_graphql_request(QUERIES["details"])
|
||||
if not data.get("vms"):
|
||||
raise ToolError("No VM data returned from server")
|
||||
vms = data["vms"].get("domains") or data["vms"].get("domain") or []
|
||||
if isinstance(vms, dict):
|
||||
vms = [vms]
|
||||
for vm in vms:
|
||||
if (
|
||||
vm.get("uuid") == vm_id
|
||||
@@ -121,9 +134,6 @@ def register_vm_tool(mcp: FastMCP) -> None:
|
||||
return dict(vm)
|
||||
available = [f"{v.get('name')} (UUID: {v.get('uuid')})" for v in vms]
|
||||
raise ToolError(f"VM '{vm_id}' not found. Available: {', '.join(available)}")
|
||||
if action == "details":
|
||||
raise ToolError("No VM data returned from server")
|
||||
return {"vms": []}
|
||||
|
||||
# Mutations
|
||||
if action in MUTATIONS:
|
||||
@@ -142,12 +152,10 @@ def register_vm_tool(mcp: FastMCP) -> None:
|
||||
except ToolError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in unraid_vm action={action}: {e}", exc_info=True)
|
||||
msg = str(e)
|
||||
if "VMs are not available" in msg:
|
||||
if "VMs are not available" in str(e):
|
||||
raise ToolError(
|
||||
"VMs not available on this server. Check VM support is enabled."
|
||||
) from e
|
||||
raise ToolError(f"Failed to execute vm/{action}: {msg}") from e
|
||||
raise
|
||||
|
||||
logger.info("VM tool registered successfully")
|
||||
|
||||
13
uv.lock
generated
13
uv.lock
generated
@@ -1706,15 +1706,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/14/2c/dee705c427875402200fe779eb8a3c00ccb349471172c41178336e9599cc/typer-0.23.2-py3-none-any.whl", hash = "sha256:e9c8dc380f82450b3c851a9b9d5a0edf95d1d6456ae70c517d8b06a50c7a9978", size = 56834, upload-time = "2026-02-16T18:52:39.308Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "types-pytz"
|
||||
version = "2025.2.0.20251108"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/40/ff/c047ddc68c803b46470a357454ef76f4acd8c1088f5cc4891cdd909bfcf6/types_pytz-2025.2.0.20251108.tar.gz", hash = "sha256:fca87917836ae843f07129567b74c1929f1870610681b4c92cb86a3df5817bdb", size = 10961, upload-time = "2025-11-08T02:55:57.001Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e7/c1/56ef16bf5dcd255155cc736d276efa6ae0a5c26fd685e28f0412a4013c01/types_pytz-2025.2.0.20251108-py3-none-any.whl", hash = "sha256:0f1c9792cab4eb0e46c52f8845c8f77cf1e313cb3d68bf826aa867fe4717d91c", size = 10116, upload-time = "2025-11-08T02:55:56.194Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typing-extensions"
|
||||
version = "4.15.0"
|
||||
@@ -1745,7 +1736,6 @@ dependencies = [
|
||||
{ name = "fastmcp" },
|
||||
{ name = "httpx" },
|
||||
{ name = "python-dotenv" },
|
||||
{ name = "pytz" },
|
||||
{ name = "rich" },
|
||||
{ name = "uvicorn", extra = ["standard"] },
|
||||
{ name = "websockets" },
|
||||
@@ -1762,7 +1752,6 @@ dev = [
|
||||
{ name = "ruff" },
|
||||
{ name = "twine" },
|
||||
{ name = "ty" },
|
||||
{ name = "types-pytz" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
@@ -1771,7 +1760,6 @@ requires-dist = [
|
||||
{ name = "fastmcp", specifier = ">=2.14.5" },
|
||||
{ name = "httpx", specifier = ">=0.28.1" },
|
||||
{ name = "python-dotenv", specifier = ">=1.1.1" },
|
||||
{ name = "pytz", specifier = ">=2025.2" },
|
||||
{ name = "rich", specifier = ">=14.1.0" },
|
||||
{ name = "uvicorn", extras = ["standard"], specifier = ">=0.35.0" },
|
||||
{ name = "websockets", specifier = ">=15.0.1" },
|
||||
@@ -1788,7 +1776,6 @@ dev = [
|
||||
{ name = "ruff", specifier = ">=0.12.8" },
|
||||
{ name = "twine", specifier = ">=6.0.1" },
|
||||
{ name = "ty", specifier = ">=0.0.15" },
|
||||
{ name = "types-pytz", specifier = ">=2025.2.0.20250809" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Reference in New Issue
Block a user