feat: add 5 notification mutations + comprehensive refactors from PR review

New notification actions (archive_many, create_unique, unarchive_many,
unarchive_all, recalculate) bring unraid_notifications to 14 actions.

Also includes continuation of CodeRabbit/PR review fixes:
- Remove redundant try-except in virtualization.py (silent failure fix)
- Add QueryCache protocol with get/put/invalidate_all to core/client.py
- Refactor subscriptions (manager, diagnostics, resources, utils)
- Update config (logging, settings) for improved structure
- Expand test coverage: http_layer, safety guards, schema validation
- Minor cleanups: array, docker, health, keys tools

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Jacob Magar
2026-03-13 01:54:55 -04:00
parent 06f18f32fc
commit 60defc35ca
27 changed files with 2508 additions and 423 deletions

View File

@@ -22,7 +22,7 @@ from ..core.exceptions import ToolError
from ..core.utils import safe_display_url
from .manager import subscription_manager
from .resources import ensure_subscriptions_started
from .utils import build_ws_ssl_context, build_ws_url
from .utils import _analyze_subscription_status, build_ws_ssl_context, build_ws_url
_ALLOWED_SUBSCRIPTION_NAMES = frozenset(
@@ -187,8 +187,10 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
# Get comprehensive status
status = await subscription_manager.get_subscription_status()
# Initialize connection issues list with proper type
connection_issues: list[dict[str, Any]] = []
# Analyze connection issues and error counts via the shared helper.
# This ensures "invalid_uri" and all other error states are counted
# consistently with the health tool's _diagnose_subscriptions path.
error_count, connection_issues = _analyze_subscription_status(status)
# Add environment info with explicit typing
diagnostic_info: dict[str, Any] = {
@@ -210,7 +212,7 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
),
"active_count": len(subscription_manager.active_subscriptions),
"with_data": len(subscription_manager.resource_data),
"in_error_state": 0,
"in_error_state": error_count,
"connection_issues": connection_issues,
},
}
@@ -219,23 +221,6 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
with contextlib.suppress(ValueError):
diagnostic_info["environment"]["websocket_url"] = build_ws_url()
# Analyze issues
for sub_name, sub_status in status.items():
runtime = sub_status.get("runtime", {})
connection_state = runtime.get("connection_state", "unknown")
if connection_state in ["error", "auth_failed", "timeout", "max_retries_exceeded"]:
diagnostic_info["summary"]["in_error_state"] += 1
if runtime.get("last_error"):
connection_issues.append(
{
"subscription": sub_name,
"state": connection_state,
"error": runtime["last_error"],
}
)
# Add troubleshooting recommendations
recommendations: list[str] = []

View File

@@ -45,7 +45,7 @@ def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]:
elif (
key == "content"
and isinstance(value, str)
and len(value.encode("utf-8", errors="replace")) > _MAX_RESOURCE_DATA_BYTES
and len(value) > _MAX_RESOURCE_DATA_BYTES # fast pre-check on char count
):
lines = value.splitlines()
original_line_count = len(lines)
@@ -54,19 +54,15 @@ def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]:
if len(lines) > _MAX_RESOURCE_DATA_LINES:
lines = lines[-_MAX_RESOURCE_DATA_LINES:]
# Enforce byte cap while preserving whole-line boundaries where possible.
truncated = "\n".join(lines)
truncated_bytes = truncated.encode("utf-8", errors="replace")
while len(lines) > 1 and len(truncated_bytes) > _MAX_RESOURCE_DATA_BYTES:
lines = lines[1:]
truncated = "\n".join(lines)
truncated_bytes = truncated.encode("utf-8", errors="replace")
# Last resort: if a single line still exceeds cap, hard-cap bytes.
if len(truncated_bytes) > _MAX_RESOURCE_DATA_BYTES:
truncated = truncated_bytes[-_MAX_RESOURCE_DATA_BYTES :].decode(
"utf-8", errors="ignore"
)
# Encode once and slice bytes instead of O(n²) line-trim loop
encoded = truncated.encode("utf-8", errors="replace")
if len(encoded) > _MAX_RESOURCE_DATA_BYTES:
truncated = encoded[-_MAX_RESOURCE_DATA_BYTES:].decode("utf-8", errors="ignore")
# Strip partial first line that may have been cut mid-character
nl_pos = truncated.find("\n")
if nl_pos != -1:
truncated = truncated[nl_pos + 1 :]
logger.warning(
f"[RESOURCE] Capped log content from {original_line_count} to "
@@ -202,6 +198,16 @@ class SubscriptionManager:
else:
logger.warning(f"[SUBSCRIPTION:{subscription_name}] No active subscription to stop")
async def stop_all(self) -> None:
    """Stop all active subscriptions (called during server shutdown)."""
    # Snapshot the names first: stop_subscription mutates active_subscriptions,
    # and iterating the live dict while it shrinks would break the loop.
    names_to_stop = list(self.active_subscriptions.keys())
    for name in names_to_stop:
        try:
            await self.stop_subscription(name)
        except Exception as e:
            # Best-effort shutdown: record the failure and keep stopping the rest.
            logger.error(f"[SHUTDOWN] Error stopping subscription '{name}': {e}", exc_info=True)
    logger.info(f"[SHUTDOWN] Stopped {len(names_to_stop)} subscription(s)")
async def _subscription_loop(
self, subscription_name: str, query: str, variables: dict[str, Any] | None
) -> None:
@@ -512,9 +518,11 @@ class SubscriptionManager:
f"({connected_duration:.0f}s < {_STABLE_CONNECTION_SECONDS}s), "
f"keeping retry counter at {self.reconnect_attempts.get(subscription_name, 0)}"
)
# Calculate backoff delay
retry_delay = min(retry_delay * 1.5, max_retry_delay)
# Only escalate backoff when connection was NOT stable
retry_delay = min(retry_delay * 1.5, max_retry_delay)
else:
# No connection was established — escalate backoff
retry_delay = min(retry_delay * 1.5, max_retry_delay)
logger.info(
f"[WEBSOCKET:{subscription_name}] Reconnecting in {retry_delay:.1f} seconds..."
)

View File

@@ -4,8 +4,10 @@ This module defines MCP resources that bridge between the subscription manager
and the MCP protocol, providing fallback queries when subscription data is unavailable.
"""
import asyncio
import json
import os
from typing import Final
import anyio
from fastmcp import FastMCP
@@ -16,22 +18,29 @@ from .manager import subscription_manager
# Global flag to track subscription startup
_subscriptions_started = False
_startup_lock: Final[asyncio.Lock] = asyncio.Lock()
async def ensure_subscriptions_started() -> None:
    """Ensure subscriptions are started, called from async context.

    Uses double-checked locking: a cheap flag test on the fast path, then a
    re-check under ``_startup_lock`` so concurrent first callers cannot both
    run the startup sequence. On failure the flag stays unset so a later call
    can retry.
    """
    global _subscriptions_started

    # Fast-path: skip lock if already started.
    if _subscriptions_started:
        return

    # Slow-path: acquire lock for initialization (double-checked locking).
    async with _startup_lock:
        if _subscriptions_started:
            # Another task completed startup while we waited on the lock.
            return
        logger.info("[STARTUP] First async operation detected, starting subscriptions...")
        try:
            await autostart_subscriptions()
            _subscriptions_started = True
            logger.info("[STARTUP] Subscriptions started successfully")
        except Exception as e:
            # Best-effort: log and leave the flag unset; callers are not failed.
            logger.error(f"[STARTUP] Failed to start subscriptions: {e}", exc_info=True)
async def autostart_subscriptions() -> None:

View File

@@ -1,6 +1,7 @@
"""Shared utilities for the subscription system."""
import ssl as _ssl
from typing import Any
from ..config.settings import UNRAID_API_URL, UNRAID_VERIFY_SSL
@@ -52,3 +53,37 @@ def build_ws_ssl_context(ws_url: str) -> _ssl.SSLContext | None:
ctx.check_hostname = False
ctx.verify_mode = _ssl.CERT_NONE
return ctx
def _analyze_subscription_status(
status: dict[str, Any],
) -> tuple[int, list[dict[str, Any]]]:
"""Analyze subscription status dict, returning error count and connection issues.
This is the canonical, shared implementation used by both the health tool
and the subscription diagnostics tool.
Args:
status: Dict of subscription name -> status info from get_subscription_status().
Returns:
Tuple of (error_count, connection_issues_list).
"""
error_count = 0
connection_issues: list[dict[str, Any]] = []
for sub_name, sub_status in status.items():
runtime = sub_status.get("runtime", {})
conn_state = runtime.get("connection_state", "unknown")
if conn_state in ("error", "auth_failed", "timeout", "max_retries_exceeded", "invalid_uri"):
error_count += 1
if runtime.get("last_error"):
connection_issues.append(
{
"subscription": sub_name,
"state": conn_state,
"error": runtime["last_error"],
}
)
return error_count, connection_issues