fix: address 18 PR review comments (threads 1-18)

Threads 1, 2, 3 — test hygiene:
- Move elicit_and_configure/elicit_reset_confirmation to module-level imports
  in unraid.py so tests can patch at unraid_mcp.tools.unraid.* (thread 2)
- Add return type annotations to _make_tool() in test_customization.py (thread 1)
- Replace unused _mock_ensure_started fixture params with @usefixtures (thread 3)

Thread 4 — remove dead 'connect' subaction from _SYSTEM_QUERIES; the subaction
was always rejected with a ToolError, creating an inconsistent contract.

Thread 5 — centralize two inline "query { online }" strings by reusing
_SYSTEM_QUERIES["online"]; add _DOCKER_QUERIES["_resolve"] for container-name
resolution instead of an inline query literal.

Threads 14, 15, 16, 17, 18 — test improvements:
- test-tools.sh: reword header to "broad non-destructive smoke coverage" (thread 14)
- test-tools.sh: add _json_payload() helper using jq --arg for safe JSON
  construction; replace all printf-based payloads (thread 15)
- test_input_validation.py: add return type annotations to _make_tool and all
  nested _run_test coroutines (thread 16)
- test_query_validation.py: extract _all_domain_dicts() shared helper to
  eliminate the duplicate 22-item registry (thread 17)
- test_query_validation.py: tighten regression threshold from 50 → 90 (thread 18)
This commit is contained in:
Jacob Magar
2026-03-16 10:01:12 -04:00
parent 884319ab11
commit cf9449a15d
10 changed files with 252 additions and 177 deletions

View File

@@ -7,6 +7,11 @@ separate modules for configuration, core functionality, subscriptions, and tools
import sys
from fastmcp import FastMCP
from fastmcp.server.middleware.caching import CallToolSettings, ResponseCachingMiddleware
from fastmcp.server.middleware.error_handling import ErrorHandlingMiddleware
from fastmcp.server.middleware.logging import LoggingMiddleware
from fastmcp.server.middleware.rate_limiting import SlidingWindowRateLimitingMiddleware
from fastmcp.server.middleware.response_limiting import ResponseLimitingMiddleware
from .config.logging import logger
from .config.settings import (
@@ -22,11 +27,59 @@ from .subscriptions.resources import register_subscription_resources
from .tools.unraid import register_unraid_tool
# Middleware chain — outermost first; each layer wraps everything after it:
# logging → error_handling → rate_limiter → response_limiter → cache → tool

# 1. Request logging for tools/call and resources/read (method, duration, errors).
#    Sits outermost so it observes errors already normalized by error_handling.
_logging_middleware = LoggingMiddleware(
    logger=logger,
    methods=["tools/call", "resources/read"],
)

# 2. Convert any unhandled exception into a proper MCP error; keeps per
#    (exception_type:method) error_counts consumed by the health diagnose tool.
error_middleware = ErrorHandlingMiddleware(logger=logger, include_traceback=True)

# 3. The Unraid API allows 100 requests per 10 seconds; a sliding window of
#    90 requests per minute keeps us comfortably under that cap.
_rate_limiter = SlidingWindowRateLimitingMiddleware(max_requests=90, window_minutes=1)

# 4. Cap tool responses at 512 KB to protect the client context window —
#    oversized payloads are truncated with a clear suffix instead of erroring.
_response_limiter = ResponseLimitingMiddleware(max_size=512_000)

# 5. In-memory cache for tool calls (MemoryStore default — no extra deps).
#    The short 30 s TTL absorbs bursts of duplicate requests while keeping data
#    fresh; destructive calls won't hit the cache in practice (unique
#    confirm=True + IDs). List/resource/prompt lookups are cheap, so caching is
#    explicitly disabled for them.
cache_middleware = ResponseCachingMiddleware(
    call_tool_settings=CallToolSettings(ttl=30, included_tools=["unraid"]),
    list_tools_settings={"enabled": False},
    list_resources_settings={"enabled": False},
    list_prompts_settings={"enabled": False},
    read_resource_settings={"enabled": False},
    get_prompt_settings={"enabled": False},
)

# Assemble the FastMCP server; the middleware list is ordered outermost-first.
mcp = FastMCP(
    name="Unraid MCP Server",
    instructions="Provides tools to interact with an Unraid server's GraphQL API.",
    version=VERSION,
    middleware=[
        _logging_middleware,
        error_middleware,
        _rate_limiter,
        _response_limiter,
        cache_middleware,
    ],
)

# Note: SubscriptionManager singleton is defined in subscriptions/manager.py

View File

@@ -34,6 +34,7 @@ from ..config.logging import logger
from ..core.client import DISK_TIMEOUT, make_graphql_request
from ..core.exceptions import ToolError, tool_error_handler
from ..core.guards import gate_destructive_action
from ..core.setup import elicit_and_configure, elicit_reset_confirmation
from ..core.utils import format_bytes, format_kb, safe_get
@@ -78,7 +79,6 @@ _SYSTEM_QUERIES: dict[str, str] = {
registration { id type keyFile { location } state expiration updateExpiration }
}
""",
"connect": "query GetConnectSettings { connect { id dynamicRemoteAccess { enabledType runningType error } } }",
"variables": """
query GetSelectiveUnraidVariables {
vars {
@@ -150,10 +150,6 @@ async def _handle_system(subaction: str, device_id: str | None) -> dict[str, Any
f"Invalid subaction '{subaction}' for system. Must be one of: {sorted(_SYSTEM_SUBACTIONS)}"
)
if subaction == "connect":
raise ToolError(
"The 'connect' query is not available on this Unraid API version. Use 'settings' instead."
)
if subaction == "ups_device" and not device_id:
raise ToolError("device_id is required for system/ups_device")
@@ -302,14 +298,13 @@ async def _handle_health(subaction: str, ctx: Context | None) -> dict[str, Any]
CREDENTIALS_ENV_PATH,
UNRAID_API_URL,
)
from ..core.setup import elicit_and_configure, elicit_reset_confirmation
from ..core.utils import safe_display_url
from ..subscriptions.utils import _analyze_subscription_status
if subaction == "setup":
if CREDENTIALS_ENV_PATH.exists():
try:
await make_graphql_request("query { online }")
await make_graphql_request(_SYSTEM_QUERIES["online"])
connection_ok = True
except Exception:
connection_ok = False
@@ -343,7 +338,7 @@ async def _handle_health(subaction: str, ctx: Context | None) -> dict[str, Any]
if subaction == "test_connection":
start = time.time()
data = await make_graphql_request("query { online }")
data = await make_graphql_request(_SYSTEM_QUERIES["online"])
latency = round((time.time() - start) * 1000, 2)
return {"status": "connected", "online": data.get("online"), "latency_ms": latency}
@@ -351,12 +346,14 @@ async def _handle_health(subaction: str, ctx: Context | None) -> dict[str, Any]
return await _comprehensive_health_check()
if subaction == "diagnose":
from ..server import cache_middleware, error_middleware
from ..subscriptions.manager import subscription_manager
from ..subscriptions.resources import ensure_subscriptions_started
await ensure_subscriptions_started()
status = await subscription_manager.get_subscription_status()
error_count, connection_issues = _analyze_subscription_status(status)
cache_stats = cache_middleware.statistics()
return {
"timestamp": datetime.datetime.now(datetime.UTC).isoformat(),
"environment": {
@@ -372,6 +369,16 @@ async def _handle_health(subaction: str, ctx: Context | None) -> dict[str, Any]
"in_error_state": error_count,
"connection_issues": connection_issues,
},
"cache": {
"call_tool": {
"hits": cache_stats.call_tool.get.hit,
"misses": cache_stats.call_tool.get.miss,
"puts": cache_stats.call_tool.put.total,
}
if cache_stats.call_tool
else {"hits": 0, "misses": 0, "puts": 0},
},
"errors": error_middleware.get_error_stats(),
}
raise ToolError(f"Unhandled health subaction '{subaction}' — this is a bug")
@@ -731,6 +738,7 @@ _DOCKER_QUERIES: dict[str, str] = {
"details": "query GetContainerDetails { docker { containers(skipCache: false) { id names image imageId command created ports { ip privatePort publicPort type } sizeRootFs labels state status hostConfig { networkMode } networkSettings mounts autoStart } } }",
"networks": "query GetDockerNetworks { docker { networks { id name driver scope } } }",
"network_details": "query GetDockerNetwork { docker { networks { id name driver scope enableIPv6 internal attachable containers options labels } } }",
"_resolve": "query ResolveContainerID { docker { containers(skipCache: true) { id names } } }",
}
_DOCKER_MUTATIONS: dict[str, str] = {
@@ -767,9 +775,7 @@ def _find_container(
async def _resolve_container_id(container_id: str, *, strict: bool = False) -> str:
if _DOCKER_ID_PATTERN.match(container_id):
return container_id
data = await make_graphql_request(
"query { docker { containers(skipCache: true) { id names } } }"
)
data = await make_graphql_request(_DOCKER_QUERIES["_resolve"])
containers = safe_get(data, "docker", "containers", default=[])
if _DOCKER_SHORT_ID_PATTERN.match(container_id):
id_lower = container_id.lower()