fix: apply all PR review agent findings (silent failures, type safety, test gaps)

Addresses issues found by 4 parallel review agents (code-reviewer,
silent-failure-hunter, type-design-analyzer, pr-test-analyzer).

Source fixes:
- core/utils.py: add public safe_display_url() (moved from tools/health.py)
- core/client.py: rename _redact_sensitive → redact_sensitive (public API)
- core/types.py: add SubscriptionData.__post_init__ for tz-aware datetime
  enforcement; remove 6 unused type aliases (SystemHealth, APIResponse, etc.)
- subscriptions/manager.py: add exc_info=True to both except-Exception blocks;
  add except ValueError break-on-config-error before retry loop; import
  redact_sensitive by new public name
- subscriptions/resources.py: re-raise in autostart_subscriptions() so
  ensure_subscriptions_started() doesn't permanently set _subscriptions_started
- subscriptions/diagnostics.py: except ToolError: raise before broad except;
  use safe_display_url() instead of raw URL slice
- tools/health.py: move _safe_display_url to core/utils; add exc_info=True;
  raise ToolError (not return dict) on ImportError
- tools/info.py: use get_args(INFO_ACTIONS) instead of INFO_ACTIONS.__args__
- tools/{array,docker,keys,notifications,rclone,storage,virtualization}.py:
  add Literal-vs-ALL_ACTIONS sync check at import time

Test fixes:
- test_health.py: import safe_display_url from core.utils; update
  test_diagnose_import_error_internal to expect ToolError (not error dict)
- test_storage.py: add 3 safe_get tests for zero/False/empty-string values
- test_subscription_manager.py: add TestCapLogContentSingleMassiveLine (2 tests)
- test_client.py: rename _redact_sensitive → redact_sensitive; add tests for
  new sensitive keys and is_cacheable explicit-keyword form
This commit is contained in:
Jacob Magar
2026-02-19 02:23:04 -05:00
parent 348f4149a5
commit 1751bc2984
28 changed files with 354 additions and 187 deletions

View File

@@ -16,6 +16,7 @@ import websockets.exceptions
from unraid_mcp.subscriptions.manager import SubscriptionManager from unraid_mcp.subscriptions.manager import SubscriptionManager
pytestmark = pytest.mark.integration pytestmark = pytest.mark.integration

View File

@@ -12,9 +12,9 @@ from unraid_mcp.core.client import (
DISK_TIMEOUT, DISK_TIMEOUT,
_QueryCache, _QueryCache,
_RateLimiter, _RateLimiter,
_redact_sensitive,
is_idempotent_error, is_idempotent_error,
make_graphql_request, make_graphql_request,
redact_sensitive,
) )
from unraid_mcp.core.exceptions import ToolError from unraid_mcp.core.exceptions import ToolError
@@ -60,7 +60,7 @@ class TestIsIdempotentError:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# _redact_sensitive # redact_sensitive
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -69,36 +69,36 @@ class TestRedactSensitive:
def test_flat_dict(self) -> None: def test_flat_dict(self) -> None:
data = {"username": "admin", "password": "hunter2", "host": "10.0.0.1"} data = {"username": "admin", "password": "hunter2", "host": "10.0.0.1"}
result = _redact_sensitive(data) result = redact_sensitive(data)
assert result["username"] == "admin" assert result["username"] == "admin"
assert result["password"] == "***" assert result["password"] == "***"
assert result["host"] == "10.0.0.1" assert result["host"] == "10.0.0.1"
def test_nested_dict(self) -> None: def test_nested_dict(self) -> None:
data = {"config": {"apiKey": "abc123", "url": "http://host"}} data = {"config": {"apiKey": "abc123", "url": "http://host"}}
result = _redact_sensitive(data) result = redact_sensitive(data)
assert result["config"]["apiKey"] == "***" assert result["config"]["apiKey"] == "***"
assert result["config"]["url"] == "http://host" assert result["config"]["url"] == "http://host"
def test_list_of_dicts(self) -> None: def test_list_of_dicts(self) -> None:
data = [{"token": "t1"}, {"name": "safe"}] data = [{"token": "t1"}, {"name": "safe"}]
result = _redact_sensitive(data) result = redact_sensitive(data)
assert result[0]["token"] == "***" assert result[0]["token"] == "***"
assert result[1]["name"] == "safe" assert result[1]["name"] == "safe"
def test_deeply_nested(self) -> None: def test_deeply_nested(self) -> None:
data = {"a": {"b": {"c": {"secret": "deep"}}}} data = {"a": {"b": {"c": {"secret": "deep"}}}}
result = _redact_sensitive(data) result = redact_sensitive(data)
assert result["a"]["b"]["c"]["secret"] == "***" assert result["a"]["b"]["c"]["secret"] == "***"
def test_non_dict_passthrough(self) -> None: def test_non_dict_passthrough(self) -> None:
assert _redact_sensitive("plain_string") == "plain_string" assert redact_sensitive("plain_string") == "plain_string"
assert _redact_sensitive(42) == 42 assert redact_sensitive(42) == 42
assert _redact_sensitive(None) is None assert redact_sensitive(None) is None
def test_case_insensitive_keys(self) -> None: def test_case_insensitive_keys(self) -> None:
data = {"Password": "p1", "TOKEN": "t1", "ApiKey": "k1", "Secret": "s1", "Key": "x1"} data = {"Password": "p1", "TOKEN": "t1", "ApiKey": "k1", "Secret": "s1", "Key": "x1"}
result = _redact_sensitive(data) result = redact_sensitive(data)
for v in result.values(): for v in result.values():
assert v == "***" assert v == "***"
@@ -112,7 +112,7 @@ class TestRedactSensitive:
"username": "safe", "username": "safe",
"host": "safe", "host": "safe",
} }
result = _redact_sensitive(data) result = redact_sensitive(data)
assert result["user_password"] == "***" assert result["user_password"] == "***"
assert result["api_key_value"] == "***" assert result["api_key_value"] == "***"
assert result["auth_token_expiry"] == "***" assert result["auth_token_expiry"] == "***"
@@ -122,12 +122,26 @@ class TestRedactSensitive:
def test_mixed_list_content(self) -> None: def test_mixed_list_content(self) -> None:
data = [{"key": "val"}, "string", 123, [{"token": "inner"}]] data = [{"key": "val"}, "string", 123, [{"token": "inner"}]]
result = _redact_sensitive(data) result = redact_sensitive(data)
assert result[0]["key"] == "***" assert result[0]["key"] == "***"
assert result[1] == "string" assert result[1] == "string"
assert result[2] == 123 assert result[2] == 123
assert result[3][0]["token"] == "***" assert result[3][0]["token"] == "***"
def test_new_sensitive_keys_are_redacted(self) -> None:
"""PR-added keys: authorization, cookie, session, credential, passphrase, jwt."""
data = {
"authorization": "Bearer token123",
"cookie": "session=abc",
"jwt": "eyJ...",
"credential": "secret_cred",
"passphrase": "hunter2",
"session": "sess_id",
}
result = redact_sensitive(data)
for key, val in result.items():
assert val == "***", f"Key '{key}' was not redacted"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Timeout constants # Timeout constants
@@ -347,7 +361,7 @@ class TestMakeGraphQLRequestErrors:
with ( with (
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client), patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
pytest.raises(ToolError, match="invalid response.*not valid JSON"), pytest.raises(ToolError, match=r"invalid response.*not valid JSON"),
): ):
await make_graphql_request("{ info }") await make_graphql_request("{ info }")
@@ -481,7 +495,7 @@ class TestRateLimiter:
limiter = _RateLimiter(max_tokens=10, refill_rate=1.0) limiter = _RateLimiter(max_tokens=10, refill_rate=1.0)
initial = limiter.tokens initial = limiter.tokens
await limiter.acquire() await limiter.acquire()
assert limiter.tokens == initial - 1 assert limiter.tokens == pytest.approx(initial - 1, abs=1e-3)
async def test_acquire_succeeds_when_tokens_available(self) -> None: async def test_acquire_succeeds_when_tokens_available(self) -> None:
limiter = _RateLimiter(max_tokens=5, refill_rate=1.0) limiter = _RateLimiter(max_tokens=5, refill_rate=1.0)
@@ -596,6 +610,15 @@ class TestQueryCache:
"""Queries that start with 'mutation' after whitespace are not cacheable.""" """Queries that start with 'mutation' after whitespace are not cacheable."""
assert _QueryCache.is_cacheable(" mutation { ... }") is False assert _QueryCache.is_cacheable(" mutation { ... }") is False
def test_is_cacheable_with_explicit_query_keyword(self) -> None:
"""Operation names after explicit 'query' keyword must be recognized."""
assert _QueryCache.is_cacheable("query GetNetworkConfig { network { name } }") is True
assert _QueryCache.is_cacheable("query GetOwner { owner { name } }") is True
def test_is_cacheable_anonymous_query_returns_false(self) -> None:
"""Anonymous 'query { ... }' has no operation name — must not be cached."""
assert _QueryCache.is_cacheable("query { network { name } }") is False
def test_expired_entry_removed_from_store(self) -> None: def test_expired_entry_removed_from_store(self) -> None:
"""Accessing an expired entry should remove it from the internal store.""" """Accessing an expired entry should remove it from the internal store."""
cache = _QueryCache() cache = _QueryCache()

View File

@@ -80,6 +80,14 @@ class TestDockerValidation:
with pytest.raises(ToolError, match="network_id"): with pytest.raises(ToolError, match="network_id"):
await tool_fn(action="network_details") await tool_fn(action="network_details")
async def test_non_logs_action_ignores_tail_lines_validation(
self, _mock_graphql: AsyncMock
) -> None:
_mock_graphql.return_value = {"docker": {"containers": []}}
tool_fn = _make_tool()
result = await tool_fn(action="list", tail_lines=0)
assert result["containers"] == []
class TestDockerActions: class TestDockerActions:
async def test_list(self, _mock_graphql: AsyncMock) -> None: async def test_list(self, _mock_graphql: AsyncMock) -> None:
@@ -224,9 +232,22 @@ class TestDockerActions:
async def test_generic_exception_wraps_in_tool_error(self, _mock_graphql: AsyncMock) -> None: async def test_generic_exception_wraps_in_tool_error(self, _mock_graphql: AsyncMock) -> None:
_mock_graphql.side_effect = RuntimeError("unexpected failure") _mock_graphql.side_effect = RuntimeError("unexpected failure")
tool_fn = _make_tool() tool_fn = _make_tool()
with pytest.raises(ToolError, match="unexpected failure"): with pytest.raises(ToolError, match="Failed to execute docker/list"):
await tool_fn(action="list") await tool_fn(action="list")
async def test_short_id_prefix_ambiguous_rejected(self, _mock_graphql: AsyncMock) -> None:
_mock_graphql.return_value = {
"docker": {
"containers": [
{"id": "abcdef1234560000000000000000000000000000000000000000000000000000:local", "names": ["plex"]},
{"id": "abcdef1234561111111111111111111111111111111111111111111111111111:local", "names": ["sonarr"]},
]
}
}
tool_fn = _make_tool()
with pytest.raises(ToolError, match="ambiguous"):
await tool_fn(action="logs", container_id="abcdef123456")
class TestDockerMutationFailures: class TestDockerMutationFailures:
"""Tests for mutation responses that indicate failure or unexpected shapes.""" """Tests for mutation responses that indicate failure or unexpected shapes."""

View File

@@ -7,7 +7,7 @@ import pytest
from conftest import make_tool_fn from conftest import make_tool_fn
from unraid_mcp.core.exceptions import ToolError from unraid_mcp.core.exceptions import ToolError
from unraid_mcp.tools.health import _safe_display_url from unraid_mcp.core.utils import safe_display_url
@pytest.fixture @pytest.fixture
@@ -100,7 +100,7 @@ class TestHealthActions:
"unraid_mcp.tools.health._diagnose_subscriptions", "unraid_mcp.tools.health._diagnose_subscriptions",
side_effect=RuntimeError("broken"), side_effect=RuntimeError("broken"),
), ),
pytest.raises(ToolError, match="broken"), pytest.raises(ToolError, match="Failed to execute health/diagnose"),
): ):
await tool_fn(action="diagnose") await tool_fn(action="diagnose")
@@ -115,7 +115,7 @@ class TestHealthActions:
assert "cpu_sub" in result assert "cpu_sub" in result
async def test_diagnose_import_error_internal(self) -> None: async def test_diagnose_import_error_internal(self) -> None:
"""_diagnose_subscriptions catches ImportError and returns error dict.""" """_diagnose_subscriptions raises ToolError when subscription modules are unavailable."""
import sys import sys
from unraid_mcp.tools.health import _diagnose_subscriptions from unraid_mcp.tools.health import _diagnose_subscriptions
@@ -127,16 +127,18 @@ class TestHealthActions:
try: try:
# Replace the modules with objects that raise ImportError on access # Replace the modules with objects that raise ImportError on access
with patch.dict( with (
patch.dict(
sys.modules, sys.modules,
{ {
"unraid_mcp.subscriptions": None, "unraid_mcp.subscriptions": None,
"unraid_mcp.subscriptions.manager": None, "unraid_mcp.subscriptions.manager": None,
"unraid_mcp.subscriptions.resources": None, "unraid_mcp.subscriptions.resources": None,
}, },
),
pytest.raises(ToolError, match="Subscription modules not available"),
): ):
result = await _diagnose_subscriptions() await _diagnose_subscriptions()
assert "error" in result
finally: finally:
# Restore cached modules # Restore cached modules
sys.modules.update(cached) sys.modules.update(cached)
@@ -148,47 +150,47 @@ class TestHealthActions:
class TestSafeDisplayUrl: class TestSafeDisplayUrl:
"""Verify that _safe_display_url strips credentials/path and preserves scheme+host+port.""" """Verify that safe_display_url strips credentials/path and preserves scheme+host+port."""
def test_none_returns_none(self) -> None: def test_none_returns_none(self) -> None:
assert _safe_display_url(None) is None assert safe_display_url(None) is None
def test_empty_string_returns_none(self) -> None: def test_empty_string_returns_none(self) -> None:
assert _safe_display_url("") is None assert safe_display_url("") is None
def test_simple_url_scheme_and_host(self) -> None: def test_simple_url_scheme_and_host(self) -> None:
assert _safe_display_url("https://unraid.local/graphql") == "https://unraid.local" assert safe_display_url("https://unraid.local/graphql") == "https://unraid.local"
def test_preserves_port(self) -> None: def test_preserves_port(self) -> None:
assert _safe_display_url("https://10.1.0.2:31337/api/graphql") == "https://10.1.0.2:31337" assert safe_display_url("https://10.1.0.2:31337/api/graphql") == "https://10.1.0.2:31337"
def test_strips_path(self) -> None: def test_strips_path(self) -> None:
result = _safe_display_url("http://unraid.local/some/deep/path?query=1") result = safe_display_url("http://unraid.local/some/deep/path?query=1")
assert "path" not in result assert "path" not in result
assert "query" not in result assert "query" not in result
def test_strips_credentials(self) -> None: def test_strips_credentials(self) -> None:
result = _safe_display_url("https://user:password@unraid.local/graphql") result = safe_display_url("https://user:password@unraid.local/graphql")
assert "user" not in result assert "user" not in result
assert "password" not in result assert "password" not in result
assert result == "https://unraid.local" assert result == "https://unraid.local"
def test_strips_query_params(self) -> None: def test_strips_query_params(self) -> None:
result = _safe_display_url("http://host.local?token=abc&key=xyz") result = safe_display_url("http://host.local?token=abc&key=xyz")
assert "token" not in result assert "token" not in result
assert "abc" not in result assert "abc" not in result
def test_http_scheme_preserved(self) -> None: def test_http_scheme_preserved(self) -> None:
result = _safe_display_url("http://10.0.0.1:8080/api") result = safe_display_url("http://10.0.0.1:8080/api")
assert result == "http://10.0.0.1:8080" assert result == "http://10.0.0.1:8080"
def test_tailscale_url(self) -> None: def test_tailscale_url(self) -> None:
result = _safe_display_url("https://100.118.209.1:31337/graphql") result = safe_display_url("https://100.118.209.1:31337/graphql")
assert result == "https://100.118.209.1:31337" assert result == "https://100.118.209.1:31337"
def test_malformed_ipv6_url_returns_unparseable(self) -> None: def test_malformed_ipv6_url_returns_unparseable(self) -> None:
"""Malformed IPv6 brackets in netloc cause urlparse.hostname to raise ValueError.""" """Malformed IPv6 brackets in netloc cause urlparse.hostname to raise ValueError."""
# urlparse("https://[invalid") parses without error, but accessing .hostname # urlparse("https://[invalid") parses without error, but accessing .hostname
# raises ValueError: Invalid IPv6 URL — this triggers the except branch. # raises ValueError: Invalid IPv6 URL — this triggers the except branch.
result = _safe_display_url("https://[invalid") result = safe_display_url("https://[invalid")
assert result == "<unparseable>" assert result == "<unparseable>"

View File

@@ -186,7 +186,7 @@ class TestUnraidInfoTool:
async def test_generic_exception_wraps(self, _mock_graphql: AsyncMock) -> None: async def test_generic_exception_wraps(self, _mock_graphql: AsyncMock) -> None:
_mock_graphql.side_effect = RuntimeError("unexpected") _mock_graphql.side_effect = RuntimeError("unexpected")
tool_fn = _make_tool() tool_fn = _make_tool()
with pytest.raises(ToolError, match="unexpected"): with pytest.raises(ToolError, match="Failed to execute info/online"):
await tool_fn(action="online") await tool_fn(action="online")
async def test_metrics(self, _mock_graphql: AsyncMock) -> None: async def test_metrics(self, _mock_graphql: AsyncMock) -> None:
@@ -201,6 +201,7 @@ class TestUnraidInfoTool:
_mock_graphql.return_value = {"services": [{"name": "docker", "state": "running"}]} _mock_graphql.return_value = {"services": [{"name": "docker", "state": "running"}]}
tool_fn = _make_tool() tool_fn = _make_tool()
result = await tool_fn(action="services") result = await tool_fn(action="services")
assert "services" in result
assert len(result["services"]) == 1 assert len(result["services"]) == 1
assert result["services"][0]["name"] == "docker" assert result["services"][0]["name"] == "docker"
@@ -225,6 +226,7 @@ class TestUnraidInfoTool:
} }
tool_fn = _make_tool() tool_fn = _make_tool()
result = await tool_fn(action="servers") result = await tool_fn(action="servers")
assert "servers" in result
assert len(result["servers"]) == 1 assert len(result["servers"]) == 1
assert result["servers"][0]["name"] == "tower" assert result["servers"][0]["name"] == "tower"
@@ -248,6 +250,7 @@ class TestUnraidInfoTool:
} }
tool_fn = _make_tool() tool_fn = _make_tool()
result = await tool_fn(action="ups_devices") result = await tool_fn(action="ups_devices")
assert "ups_devices" in result
assert len(result["ups_devices"]) == 1 assert len(result["ups_devices"]) == 1
assert result["ups_devices"][0]["model"] == "APC" assert result["ups_devices"][0]["model"] == "APC"

View File

@@ -19,7 +19,6 @@ def _make_tool():
return make_tool_fn("unraid_mcp.tools.rclone", "register_rclone_tool", "unraid_rclone") return make_tool_fn("unraid_mcp.tools.rclone", "register_rclone_tool", "unraid_rclone")
@pytest.mark.usefixtures("_mock_graphql")
class TestRcloneValidation: class TestRcloneValidation:
async def test_delete_requires_confirm(self) -> None: async def test_delete_requires_confirm(self) -> None:
tool_fn = _make_tool() tool_fn = _make_tool()

View File

@@ -149,6 +149,15 @@ class TestSafeGet:
result = safe_get({}, "missing", default=[]) result = safe_get({}, "missing", default=[])
assert result == [] assert result == []
def test_zero_value_not_replaced_by_default(self) -> None:
assert safe_get({"temp": 0}, "temp", default="N/A") == 0
def test_false_value_not_replaced_by_default(self) -> None:
assert safe_get({"active": False}, "active", default=True) is False
def test_empty_string_not_replaced_by_default(self) -> None:
assert safe_get({"name": ""}, "name", default="unknown") == ""
class TestStorageActions: class TestStorageActions:
async def test_shares(self, _mock_graphql: AsyncMock) -> None: async def test_shares(self, _mock_graphql: AsyncMock) -> None:

View File

@@ -60,8 +60,8 @@ class TestCapLogContentSmallData:
class TestCapLogContentTruncation: class TestCapLogContentTruncation:
"""Content exceeding both byte AND line limits must be truncated to the last N lines.""" """Content exceeding both byte AND line limits must be truncated to the last N lines."""
def test_oversized_content_truncated_to_last_n_lines(self) -> None: def test_oversized_content_truncated_and_byte_capped(self) -> None:
# 200 lines, limit 50 lines, byte limit effectively 0 → should keep last 50 lines # 200 lines, tiny byte limit: must keep recent content within byte cap.
lines = [f"line {i}" for i in range(200)] lines = [f"line {i}" for i in range(200)]
data = {"content": "\n".join(lines)} data = {"content": "\n".join(lines)}
with ( with (
@@ -70,14 +70,13 @@ class TestCapLogContentTruncation:
): ):
result = _cap_log_content(data) result = _cap_log_content(data)
result_lines = result["content"].splitlines() result_lines = result["content"].splitlines()
assert len(result_lines) == 50 assert len(result["content"].encode("utf-8", errors="replace")) <= 10
# Must be the LAST 50 lines # Must keep the most recent line suffix.
assert result_lines[0] == "line 150"
assert result_lines[-1] == "line 199" assert result_lines[-1] == "line 199"
def test_content_with_fewer_lines_than_limit_not_truncated(self) -> None: def test_content_with_fewer_lines_than_limit_still_honors_byte_cap(self) -> None:
"""If byte limit exceeded but line count ≤ limit → keep original (not truncated).""" """If byte limit is exceeded, output must still be capped even with few lines."""
# 30 lines but byte limit 10 and line limit 50 → 30 < 50 so no truncation # 30 lines, byte limit 10, line limit 50 -> must cap bytes regardless of line count
lines = [f"line {i}" for i in range(30)] lines = [f"line {i}" for i in range(30)]
data = {"content": "\n".join(lines)} data = {"content": "\n".join(lines)}
with ( with (
@@ -85,8 +84,7 @@ class TestCapLogContentTruncation:
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50), patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50),
): ):
result = _cap_log_content(data) result = _cap_log_content(data)
# Original content preserved assert len(result["content"].encode("utf-8", errors="replace")) <= 10
assert result["content"] == data["content"]
def test_non_content_keys_preserved_alongside_truncated_content(self) -> None: def test_non_content_keys_preserved_alongside_truncated_content(self) -> None:
lines = [f"line {i}" for i in range(200)] lines = [f"line {i}" for i in range(200)]
@@ -98,7 +96,7 @@ class TestCapLogContentTruncation:
result = _cap_log_content(data) result = _cap_log_content(data)
assert result["path"] == "/var/log/syslog" assert result["path"] == "/var/log/syslog"
assert result["total_lines"] == 200 assert result["total_lines"] == 200
assert len(result["content"].splitlines()) == 50 assert len(result["content"].encode("utf-8", errors="replace")) <= 10
class TestCapLogContentNested: class TestCapLogContentNested:
@@ -112,7 +110,7 @@ class TestCapLogContentNested:
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50), patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50),
): ):
result = _cap_log_content(data) result = _cap_log_content(data)
assert len(result["logFile"]["content"].splitlines()) == 50 assert len(result["logFile"]["content"].encode("utf-8", errors="replace")) <= 10
assert result["logFile"]["path"] == "/var/log/syslog" assert result["logFile"]["path"] == "/var/log/syslog"
def test_deeply_nested_content_capped(self) -> None: def test_deeply_nested_content_capped(self) -> None:
@@ -123,9 +121,36 @@ class TestCapLogContentNested:
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50), patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50),
): ):
result = _cap_log_content(data) result = _cap_log_content(data)
assert len(result["outer"]["inner"]["content"].splitlines()) == 50 assert len(result["outer"]["inner"]["content"].encode("utf-8", errors="replace")) <= 10
def test_nested_non_content_keys_unaffected(self) -> None: def test_nested_non_content_keys_unaffected(self) -> None:
data = {"metrics": {"cpu": 42.5, "memory": 8192}} data = {"metrics": {"cpu": 42.5, "memory": 8192}}
result = _cap_log_content(data) result = _cap_log_content(data)
assert result == data assert result == data
class TestCapLogContentSingleMassiveLine:
"""A single line larger than the byte cap must be hard-capped at byte level."""
def test_single_massive_line_hard_caps_bytes(self) -> None:
# One line, no newlines, larger than the byte cap.
# The while-loop can't reduce it (len(lines) == 1), so the
# last-resort byte-slice path at manager.py:65-69 must fire.
huge_content = "x" * 200
data = {"content": huge_content}
with (
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10),
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 5_000),
):
result = _cap_log_content(data)
assert len(result["content"].encode("utf-8", errors="replace")) <= 10
def test_single_massive_line_input_not_mutated(self) -> None:
huge_content = "x" * 200
data = {"content": huge_content}
with (
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10),
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 5_000),
):
_cap_log_content(data)
assert data["content"] == huge_content

View File

@@ -1,13 +1,6 @@
"""Unraid MCP Server Package. """Unraid MCP Server Package."""
A modular MCP (Model Context Protocol) server that provides tools to interact from .version import VERSION
with an Unraid server's GraphQL API.
"""
from importlib.metadata import PackageNotFoundError, version
try: __version__ = VERSION
__version__ = version("unraid-mcp")
except PackageNotFoundError:
__version__ = "0.0.0"

View File

@@ -47,7 +47,7 @@ class OverwriteFileHandler(logging.FileHandler):
"""Emit a record, checking file size periodically and overwriting if needed.""" """Emit a record, checking file size periodically and overwriting if needed."""
self._emit_count += 1 self._emit_count += 1
if ( if (
self._emit_count % self._check_interval == 0 (self._emit_count == 1 or self._emit_count % self._check_interval == 0)
and self.stream and self.stream
and hasattr(self.stream, "name") and hasattr(self.stream, "name")
): ):
@@ -249,5 +249,3 @@ if FASTMCP_AVAILABLE:
else: else:
# Fallback to our custom logger if FastMCP is not available # Fallback to our custom logger if FastMCP is not available
logger = setup_logger() logger = setup_logger()
# Also configure FastMCP logger for consistency
configure_fastmcp_logger_with_rich()

View File

@@ -5,12 +5,13 @@ and provides all configuration constants used throughout the application.
""" """
import os import os
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from dotenv import load_dotenv from dotenv import load_dotenv
from ..version import VERSION as APP_VERSION
# Get the script directory (config module location) # Get the script directory (config module location)
SCRIPT_DIR = Path(__file__).parent # /home/user/code/unraid-mcp/unraid_mcp/config/ SCRIPT_DIR = Path(__file__).parent # /home/user/code/unraid-mcp/unraid_mcp/config/
@@ -31,12 +32,6 @@ for dotenv_path in dotenv_paths:
load_dotenv(dotenv_path=dotenv_path) load_dotenv(dotenv_path=dotenv_path)
break break
# Application Version (single source of truth: pyproject.toml)
try:
VERSION = version("unraid-mcp")
except PackageNotFoundError:
VERSION = "0.0.0"
# Core API Configuration # Core API Configuration
UNRAID_API_URL = os.getenv("UNRAID_API_URL") UNRAID_API_URL = os.getenv("UNRAID_API_URL")
UNRAID_API_KEY = os.getenv("UNRAID_API_KEY") UNRAID_API_KEY = os.getenv("UNRAID_API_KEY")
@@ -58,12 +53,18 @@ else: # Path to CA bundle
# Logging Configuration # Logging Configuration
LOG_LEVEL_STR = os.getenv("UNRAID_MCP_LOG_LEVEL", "INFO").upper() LOG_LEVEL_STR = os.getenv("UNRAID_MCP_LOG_LEVEL", "INFO").upper()
LOG_FILE_NAME = os.getenv("UNRAID_MCP_LOG_FILE", "unraid-mcp.log") LOG_FILE_NAME = os.getenv("UNRAID_MCP_LOG_FILE", "unraid-mcp.log")
# Use /app/logs in Docker, project-relative logs/ directory otherwise # Use /.dockerenv as the container indicator for robust Docker detection.
LOGS_DIR = Path("/app/logs") if Path("/app").is_dir() else PROJECT_ROOT / "logs" IS_DOCKER = Path("/.dockerenv").exists()
LOGS_DIR = Path("/app/logs") if IS_DOCKER else PROJECT_ROOT / "logs"
LOG_FILE_PATH = LOGS_DIR / LOG_FILE_NAME LOG_FILE_PATH = LOGS_DIR / LOG_FILE_NAME
# Ensure logs directory exists # Ensure logs directory exists; if creation fails, fall back to /tmp.
LOGS_DIR.mkdir(parents=True, exist_ok=True) try:
LOGS_DIR.mkdir(parents=True, exist_ok=True)
except OSError:
LOGS_DIR = PROJECT_ROOT / ".cache" / "logs"
LOGS_DIR.mkdir(parents=True, exist_ok=True)
LOG_FILE_PATH = LOGS_DIR / LOG_FILE_NAME
# HTTP Client Configuration # HTTP Client Configuration
TIMEOUT_CONFIG = { TIMEOUT_CONFIG = {
@@ -109,3 +110,5 @@ def get_config_summary() -> dict[str, Any]:
"config_valid": is_valid, "config_valid": is_valid,
"missing_config": missing if not is_valid else None, "missing_config": missing if not is_valid else None,
} }
# Re-export application version from a single source of truth.
VERSION = APP_VERSION

View File

@@ -7,6 +7,7 @@ to the Unraid API with proper timeout handling and error management.
import asyncio import asyncio
import hashlib import hashlib
import json import json
import re
import time import time
from typing import Any, Final from typing import Any, Final
@@ -47,14 +48,14 @@ def _is_sensitive_key(key: str) -> bool:
return any(s in key_lower for s in _SENSITIVE_KEYS) return any(s in key_lower for s in _SENSITIVE_KEYS)
def _redact_sensitive(obj: Any) -> Any: def redact_sensitive(obj: Any) -> Any:
"""Recursively redact sensitive values from nested dicts/lists.""" """Recursively redact sensitive values from nested dicts/lists."""
if isinstance(obj, dict): if isinstance(obj, dict):
return { return {
k: ("***" if _is_sensitive_key(k) else _redact_sensitive(v)) for k, v in obj.items() k: ("***" if _is_sensitive_key(k) else redact_sensitive(v)) for k, v in obj.items()
} }
if isinstance(obj, list): if isinstance(obj, list):
return [_redact_sensitive(item) for item in obj] return [redact_sensitive(item) for item in obj]
return obj return obj
@@ -139,6 +140,7 @@ _CACHEABLE_QUERY_PREFIXES = frozenset(
) )
_CACHE_TTL_SECONDS = 60.0 _CACHE_TTL_SECONDS = 60.0
_OPERATION_NAME_PATTERN = re.compile(r"^(?:query\s+)?([_A-Za-z][_0-9A-Za-z]*)\b")
class _QueryCache: class _QueryCache:
@@ -160,9 +162,13 @@ class _QueryCache:
@staticmethod @staticmethod
def is_cacheable(query: str) -> bool: def is_cacheable(query: str) -> bool:
"""Check if a query is eligible for caching based on its operation name.""" """Check if a query is eligible for caching based on its operation name."""
if query.lstrip().startswith("mutation"): normalized = query.lstrip()
if normalized.startswith("mutation"):
return False return False
return any(prefix in query for prefix in _CACHEABLE_QUERY_PREFIXES) match = _OPERATION_NAME_PATTERN.match(normalized)
if not match:
return False
return match.group(1) in _CACHEABLE_QUERY_PREFIXES
def get(self, query: str, variables: dict[str, Any] | None) -> dict[str, Any] | None: def get(self, query: str, variables: dict[str, Any] | None) -> dict[str, Any] | None:
"""Return cached result if present and not expired, else None.""" """Return cached result if present and not expired, else None."""
@@ -324,7 +330,7 @@ async def make_graphql_request(
logger.debug(f"Making GraphQL request to {UNRAID_API_URL}:") logger.debug(f"Making GraphQL request to {UNRAID_API_URL}:")
logger.debug(f"Query: {query[:200]}{'...' if len(query) > 200 else ''}") # Log truncated query logger.debug(f"Query: {query[:200]}{'...' if len(query) > 200 else ''}") # Log truncated query
if variables: if variables:
logger.debug(f"Variables: {_redact_sensitive(variables)}") logger.debug(f"Variables: {redact_sensitive(variables)}")
try: try:
# Rate limit: consume a token before making the request # Rate limit: consume a token before making the request

View File

@@ -45,13 +45,12 @@ def tool_error_handler(
except ToolError: except ToolError:
raise raise
except TimeoutError as e: except TimeoutError as e:
logger.error( logger.exception(f"Timeout in unraid_{tool_name} action={action}: request exceeded time limit")
f"Timeout in unraid_{tool_name} action={action}: request exceeded time limit",
exc_info=True,
)
raise ToolError( raise ToolError(
f"Request timed out executing {tool_name}/{action}. The Unraid API did not respond in time." f"Request timed out executing {tool_name}/{action}. The Unraid API did not respond in time."
) from e ) from e
except Exception as e: except Exception as e:
logger.error(f"Error in unraid_{tool_name} action={action}: {e}", exc_info=True) logger.exception(f"Error in unraid_{tool_name} action={action}")
raise ToolError(f"Failed to execute {tool_name}/{action}: {e!s}") from e raise ToolError(
f"Failed to execute {tool_name}/{action}. Check server logs for details."
) from e

View File

@@ -20,33 +20,10 @@ class SubscriptionData:
last_updated: datetime # Must be timezone-aware (UTC) last_updated: datetime # Must be timezone-aware (UTC)
subscription_type: str subscription_type: str
def __post_init__(self) -> None:
@dataclass(slots=True) if self.last_updated.tzinfo is None:
class SystemHealth: raise ValueError(
"""Container for system health status information. "last_updated must be timezone-aware; use datetime.now(UTC)"
)
Note: last_checked must be timezone-aware (use datetime.now(UTC)). if not self.subscription_type.strip():
""" raise ValueError("subscription_type must be a non-empty string")
is_healthy: bool
issues: list[str]
warnings: list[str]
last_checked: datetime # Must be timezone-aware (UTC)
component_status: dict[str, str]
@dataclass(slots=True)
class APIResponse:
"""Container for standardized API response data."""
success: bool
data: dict[str, Any] | None = None
error: str | None = None
metadata: dict[str, Any] | None = None
# Type aliases for common data structures
ConfigValue = str | int | bool | float | None
ConfigDict = dict[str, ConfigValue]
GraphQLVariables = dict[str, Any]
HealthStatus = dict[str, str | bool | int | list[Any]]

View File

@@ -1,6 +1,7 @@
"""Shared utility functions for Unraid MCP tools.""" """Shared utility functions for Unraid MCP tools."""
from typing import Any from typing import Any
from urllib.parse import urlparse
def safe_get(data: dict[str, Any], *keys: str, default: Any = None) -> Any: def safe_get(data: dict[str, Any], *keys: str, default: Any = None) -> Any:
@@ -45,6 +46,25 @@ def format_bytes(bytes_value: int | None) -> str:
return f"{value:.2f} EB" return f"{value:.2f} EB"
def safe_display_url(url: str | None) -> str | None:
"""Return a redacted URL showing only scheme + host + port.
Strips path, query parameters, credentials, and fragments to avoid
leaking internal network topology or embedded secrets (CWE-200).
"""
if not url:
return None
try:
parsed = urlparse(url)
host = parsed.hostname or "unknown"
if parsed.port:
return f"{parsed.scheme}://{host}:{parsed.port}"
return f"{parsed.scheme}://{host}"
except ValueError:
# urlparse raises ValueError for invalid URLs (e.g. contains control chars)
return "<unparseable>"
def format_kb(k: Any) -> str: def format_kb(k: Any) -> str:
"""Format kilobyte values into human-readable sizes. """Format kilobyte values into human-readable sizes.

View File

@@ -19,6 +19,7 @@ from websockets.typing import Subprotocol
from ..config.logging import logger from ..config.logging import logger
from ..config.settings import UNRAID_API_KEY, UNRAID_API_URL from ..config.settings import UNRAID_API_KEY, UNRAID_API_URL
from ..core.exceptions import ToolError from ..core.exceptions import ToolError
from ..core.utils import safe_display_url
from .manager import subscription_manager from .manager import subscription_manager
from .resources import ensure_subscriptions_started from .resources import ensure_subscriptions_started
from .utils import build_ws_ssl_context, build_ws_url from .utils import build_ws_ssl_context, build_ws_url
@@ -162,6 +163,8 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
"note": "Connection successful, subscription may be waiting for events", "note": "Connection successful, subscription may be waiting for events",
} }
except ToolError:
raise
except Exception as e: except Exception as e:
logger.error(f"[TEST_SUBSCRIPTION] Error: {e}", exc_info=True) logger.error(f"[TEST_SUBSCRIPTION] Error: {e}", exc_info=True)
return {"error": str(e), "query_tested": subscription_query} return {"error": str(e), "query_tested": subscription_query}
@@ -193,7 +196,7 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
"environment": { "environment": {
"auto_start_enabled": subscription_manager.auto_start_enabled, "auto_start_enabled": subscription_manager.auto_start_enabled,
"max_reconnect_attempts": subscription_manager.max_reconnect_attempts, "max_reconnect_attempts": subscription_manager.max_reconnect_attempts,
"unraid_api_url": UNRAID_API_URL[:50] + "..." if UNRAID_API_URL else None, "unraid_api_url": safe_display_url(UNRAID_API_URL),
"api_key_configured": bool(UNRAID_API_KEY), "api_key_configured": bool(UNRAID_API_KEY),
"websocket_url": None, "websocket_url": None,
}, },

View File

@@ -17,7 +17,7 @@ from websockets.typing import Subprotocol
from ..config.logging import logger from ..config.logging import logger
from ..config.settings import UNRAID_API_KEY from ..config.settings import UNRAID_API_KEY
from ..core.client import _redact_sensitive from ..core.client import redact_sensitive
from ..core.types import SubscriptionData from ..core.types import SubscriptionData
from .utils import build_ws_ssl_context, build_ws_url from .utils import build_ws_ssl_context, build_ws_url
@@ -36,8 +36,7 @@ def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]:
field (from log subscriptions) exceeds the byte limit, truncate it to the field (from log subscriptions) exceeds the byte limit, truncate it to the
most recent _MAX_RESOURCE_DATA_LINES lines. most recent _MAX_RESOURCE_DATA_LINES lines.
Note: single lines larger than _MAX_RESOURCE_DATA_BYTES are not split and The final content is guaranteed to be <= _MAX_RESOURCE_DATA_BYTES.
will still be stored at full size; only multi-line content is truncated.
""" """
result: dict[str, Any] = {} result: dict[str, Any] = {}
for key, value in data.items(): for key, value in data.items():
@@ -49,17 +48,33 @@ def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]:
and len(value.encode("utf-8", errors="replace")) > _MAX_RESOURCE_DATA_BYTES and len(value.encode("utf-8", errors="replace")) > _MAX_RESOURCE_DATA_BYTES
): ):
lines = value.splitlines() lines = value.splitlines()
original_line_count = len(lines)
# Keep most recent lines first.
if len(lines) > _MAX_RESOURCE_DATA_LINES: if len(lines) > _MAX_RESOURCE_DATA_LINES:
truncated = "\n".join(lines[-_MAX_RESOURCE_DATA_LINES:]) lines = lines[-_MAX_RESOURCE_DATA_LINES:]
# Enforce byte cap while preserving whole-line boundaries where possible.
truncated = "\n".join(lines)
truncated_bytes = truncated.encode("utf-8", errors="replace")
while len(lines) > 1 and len(truncated_bytes) > _MAX_RESOURCE_DATA_BYTES:
lines = lines[1:]
truncated = "\n".join(lines)
truncated_bytes = truncated.encode("utf-8", errors="replace")
# Last resort: if a single line still exceeds cap, hard-cap bytes.
if len(truncated_bytes) > _MAX_RESOURCE_DATA_BYTES:
truncated = truncated_bytes[-_MAX_RESOURCE_DATA_BYTES :].decode(
"utf-8", errors="ignore"
)
logger.warning( logger.warning(
f"[RESOURCE] Capped log content from {len(lines)} to " f"[RESOURCE] Capped log content from {original_line_count} to "
f"{_MAX_RESOURCE_DATA_LINES} lines ({len(value)} -> {len(truncated)} chars)" f"{len(lines)} lines ({len(value)} -> {len(truncated)} chars)"
) )
result[key] = truncated result[key] = truncated
else: else:
result[key] = value result[key] = value
else:
result[key] = value
return result return result
@@ -148,6 +163,7 @@ class SubscriptionManager:
# Reset connection tracking # Reset connection tracking
self.reconnect_attempts[subscription_name] = 0 self.reconnect_attempts[subscription_name] = 0
self.connection_states[subscription_name] = "starting" self.connection_states[subscription_name] = "starting"
self._connection_start_times.pop(subscription_name, None)
async with self.subscription_lock: async with self.subscription_lock:
try: try:
@@ -181,6 +197,7 @@ class SubscriptionManager:
logger.debug(f"[SUBSCRIPTION:{subscription_name}] Task cancelled successfully") logger.debug(f"[SUBSCRIPTION:{subscription_name}] Task cancelled successfully")
del self.active_subscriptions[subscription_name] del self.active_subscriptions[subscription_name]
self.connection_states[subscription_name] = "stopped" self.connection_states[subscription_name] = "stopped"
self._connection_start_times.pop(subscription_name, None)
logger.info(f"[SUBSCRIPTION:{subscription_name}] Subscription stopped") logger.info(f"[SUBSCRIPTION:{subscription_name}] Subscription stopped")
else: else:
logger.warning(f"[SUBSCRIPTION:{subscription_name}] No active subscription to stop") logger.warning(f"[SUBSCRIPTION:{subscription_name}] No active subscription to stop")
@@ -322,7 +339,7 @@ class SubscriptionManager:
) )
logger.debug(f"[SUBSCRIPTION:{subscription_name}] Query: {query[:100]}...") logger.debug(f"[SUBSCRIPTION:{subscription_name}] Query: {query[:100]}...")
logger.debug( logger.debug(
f"[SUBSCRIPTION:{subscription_name}] Variables: {_redact_sensitive(variables)}" f"[SUBSCRIPTION:{subscription_name}] Variables: {redact_sensitive(variables)}"
) )
await websocket.send(json.dumps(subscription_message)) await websocket.send(json.dumps(subscription_message))
@@ -431,7 +448,8 @@ class SubscriptionManager:
logger.error(f"[PROTOCOL:{subscription_name}] JSON decode error: {e}") logger.error(f"[PROTOCOL:{subscription_name}] JSON decode error: {e}")
except Exception as e: except Exception as e:
logger.error( logger.error(
f"[DATA:{subscription_name}] Error processing message: {e}" f"[DATA:{subscription_name}] Error processing message: {e}",
exc_info=True,
) )
msg_preview = ( msg_preview = (
message[:200] message[:200]
@@ -461,14 +479,22 @@ class SubscriptionManager:
self.connection_states[subscription_name] = "invalid_uri" self.connection_states[subscription_name] = "invalid_uri"
break # Don't retry on invalid URI break # Don't retry on invalid URI
except ValueError as e:
# Non-retryable configuration error (e.g. UNRAID_API_URL not set)
error_msg = f"Configuration error: {e}"
logger.error(f"[WEBSOCKET:{subscription_name}] {error_msg}")
self.last_error[subscription_name] = error_msg
self.connection_states[subscription_name] = "error"
break # Don't retry on configuration errors
except Exception as e: except Exception as e:
error_msg = f"Unexpected error: {e}" error_msg = f"Unexpected error: {e}"
logger.error(f"[WEBSOCKET:{subscription_name}] {error_msg}") logger.error(f"[WEBSOCKET:{subscription_name}] {error_msg}", exc_info=True)
self.last_error[subscription_name] = error_msg self.last_error[subscription_name] = error_msg
self.connection_states[subscription_name] = "error" self.connection_states[subscription_name] = "error"
# Check if connection was stable before deciding on retry behavior # Check if connection was stable before deciding on retry behavior
start_time = self._connection_start_times.get(subscription_name) start_time = self._connection_start_times.pop(subscription_name, None)
if start_time is not None: if start_time is not None:
connected_duration = time.monotonic() - start_time connected_duration = time.monotonic() - start_time
if connected_duration >= _STABLE_CONNECTION_SECONDS: if connected_duration >= _STABLE_CONNECTION_SECONDS:

View File

@@ -44,6 +44,7 @@ async def autostart_subscriptions() -> None:
logger.info("[AUTOSTART] Auto-start process completed successfully") logger.info("[AUTOSTART] Auto-start process completed successfully")
except Exception as e: except Exception as e:
logger.error(f"[AUTOSTART] Failed during auto-start process: {e}", exc_info=True) logger.error(f"[AUTOSTART] Failed during auto-start process: {e}", exc_info=True)
raise # Propagate so ensure_subscriptions_started doesn't mark as started
# Optional log file subscription # Optional log file subscription
log_path = os.getenv("UNRAID_AUTOSTART_LOG_PATH") log_path = os.getenv("UNRAID_AUTOSTART_LOG_PATH")

View File

@@ -3,7 +3,7 @@
Provides the `unraid_array` tool with 5 actions for parity check management. Provides the `unraid_array` tool with 5 actions for parity check management.
""" """
from typing import Any, Literal from typing import Any, Literal, get_args
from fastmcp import FastMCP from fastmcp import FastMCP
@@ -53,6 +53,14 @@ ARRAY_ACTIONS = Literal[
"parity_status", "parity_status",
] ]
if set(get_args(ARRAY_ACTIONS)) != ALL_ACTIONS:
_missing = ALL_ACTIONS - set(get_args(ARRAY_ACTIONS))
_extra = set(get_args(ARRAY_ACTIONS)) - ALL_ACTIONS
raise RuntimeError(
f"ARRAY_ACTIONS and ALL_ACTIONS are out of sync. "
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
)
def register_array_tool(mcp: FastMCP) -> None: def register_array_tool(mcp: FastMCP) -> None:
"""Register the unraid_array tool with the FastMCP instance.""" """Register the unraid_array tool with the FastMCP instance."""

View File

@@ -5,7 +5,7 @@ logs, networks, and update management.
""" """
import re import re
from typing import Any, Literal from typing import Any, Literal, get_args
from fastmcp import FastMCP from fastmcp import FastMCP
@@ -135,6 +135,14 @@ DOCKER_ACTIONS = Literal[
"check_updates", "check_updates",
] ]
if set(get_args(DOCKER_ACTIONS)) != ALL_ACTIONS:
_missing = ALL_ACTIONS - set(get_args(DOCKER_ACTIONS))
_extra = set(get_args(DOCKER_ACTIONS)) - ALL_ACTIONS
raise RuntimeError(
f"DOCKER_ACTIONS and ALL_ACTIONS are out of sync. "
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
)
# Full PrefixedID: 64 hex chars + optional suffix (e.g., ":local") # Full PrefixedID: 64 hex chars + optional suffix (e.g., ":local")
_DOCKER_ID_PATTERN = re.compile(r"^[a-f0-9]{64}(:[a-z0-9]+)?$", re.IGNORECASE) _DOCKER_ID_PATTERN = re.compile(r"^[a-f0-9]{64}(:[a-z0-9]+)?$", re.IGNORECASE)
@@ -199,11 +207,6 @@ def get_available_container_names(containers: list[dict[str, Any]]) -> list[str]
return names return names
def _looks_like_container_id(identifier: str) -> bool:
"""Check if an identifier looks like a container ID (full or short hex prefix)."""
return bool(_DOCKER_ID_PATTERN.match(identifier) or _DOCKER_SHORT_ID_PATTERN.match(identifier))
async def _resolve_container_id(container_id: str, *, strict: bool = False) -> str: async def _resolve_container_id(container_id: str, *, strict: bool = False) -> str:
"""Resolve a container name/identifier to its actual PrefixedID. """Resolve a container name/identifier to its actual PrefixedID.
@@ -233,12 +236,21 @@ async def _resolve_container_id(container_id: str, *, strict: bool = False) -> s
# Short hex prefix: match by ID prefix before trying name matching # Short hex prefix: match by ID prefix before trying name matching
if _DOCKER_SHORT_ID_PATTERN.match(container_id): if _DOCKER_SHORT_ID_PATTERN.match(container_id):
id_lower = container_id.lower() id_lower = container_id.lower()
matches: list[dict[str, Any]] = []
for c in containers: for c in containers:
cid = (c.get("id") or "").lower() cid = (c.get("id") or "").lower()
if cid.startswith(id_lower) or cid.split(":")[0].startswith(id_lower): if cid.startswith(id_lower) or cid.split(":")[0].startswith(id_lower):
actual_id = str(c.get("id", "")) matches.append(c)
if len(matches) == 1:
actual_id = str(matches[0].get("id", ""))
logger.info(f"Resolved short ID '{container_id}' -> '{actual_id}'") logger.info(f"Resolved short ID '{container_id}' -> '{actual_id}'")
return actual_id return actual_id
if len(matches) > 1:
candidate_ids = [str(c.get("id", "")) for c in matches[:5]]
raise ToolError(
f"Short container ID prefix '{container_id}' is ambiguous. "
f"Matches: {', '.join(candidate_ids)}. Use a longer ID or exact name."
)
resolved = find_container_by_identifier(container_id, containers, strict=strict) resolved = find_container_by_identifier(container_id, containers, strict=strict)
if resolved: if resolved:
@@ -303,7 +315,7 @@ def register_docker_tool(mcp: FastMCP) -> None:
if action == "network_details" and not network_id: if action == "network_details" and not network_id:
raise ToolError("network_id is required for 'network_details' action") raise ToolError("network_id is required for 'network_details' action")
if tail_lines < 1 or tail_lines > _MAX_TAIL_LINES: if action == "logs" and (tail_lines < 1 or tail_lines > _MAX_TAIL_LINES):
raise ToolError(f"tail_lines must be between 1 and {_MAX_TAIL_LINES}, got {tail_lines}") raise ToolError(f"tail_lines must be between 1 and {_MAX_TAIL_LINES}, got {tail_lines}")
with tool_error_handler("docker", action, logger): with tool_error_handler("docker", action, logger):
@@ -335,12 +347,12 @@ def register_docker_tool(mcp: FastMCP) -> None:
if action == "networks": if action == "networks":
data = await make_graphql_request(QUERIES["networks"]) data = await make_graphql_request(QUERIES["networks"])
networks = data.get("dockerNetworks", []) networks = safe_get(data, "dockerNetworks", default=[])
return {"networks": networks} return {"networks": networks}
if action == "network_details": if action == "network_details":
data = await make_graphql_request(QUERIES["network_details"], {"id": network_id}) data = await make_graphql_request(QUERIES["network_details"], {"id": network_id})
return dict(data.get("dockerNetwork") or {}) return dict(safe_get(data, "dockerNetwork", default={}) or {})
if action == "port_conflicts": if action == "port_conflicts":
data = await make_graphql_request(QUERIES["port_conflicts"]) data = await make_graphql_request(QUERIES["port_conflicts"])

View File

@@ -6,8 +6,7 @@ connection testing, and subscription diagnostics.
import datetime import datetime
import time import time
from typing import Any, Literal from typing import Any, Literal, get_args
from urllib.parse import urlparse
from fastmcp import FastMCP from fastmcp import FastMCP
@@ -21,31 +20,21 @@ from ..config.settings import (
) )
from ..core.client import make_graphql_request from ..core.client import make_graphql_request
from ..core.exceptions import ToolError, tool_error_handler from ..core.exceptions import ToolError, tool_error_handler
from ..core.utils import safe_display_url
def _safe_display_url(url: str | None) -> str | None:
"""Return a redacted URL showing only scheme + host + port.
Strips path, query parameters, credentials, and fragments to avoid
leaking internal network topology or embedded secrets (CWE-200).
"""
if not url:
return None
try:
parsed = urlparse(url)
host = parsed.hostname or "unknown"
if parsed.port:
return f"{parsed.scheme}://{host}:{parsed.port}"
return f"{parsed.scheme}://{host}"
except ValueError:
# urlparse raises ValueError for invalid URLs (e.g. contains control chars)
return "<unparseable>"
ALL_ACTIONS = {"check", "test_connection", "diagnose"} ALL_ACTIONS = {"check", "test_connection", "diagnose"}
HEALTH_ACTIONS = Literal["check", "test_connection", "diagnose"] HEALTH_ACTIONS = Literal["check", "test_connection", "diagnose"]
if set(get_args(HEALTH_ACTIONS)) != ALL_ACTIONS:
_missing = ALL_ACTIONS - set(get_args(HEALTH_ACTIONS))
_extra = set(get_args(HEALTH_ACTIONS)) - ALL_ACTIONS
raise RuntimeError(
"HEALTH_ACTIONS and ALL_ACTIONS are out of sync. "
f"Missing in HEALTH_ACTIONS: {_missing}; extra in HEALTH_ACTIONS: {_extra}"
)
# Severity ordering: only upgrade, never downgrade # Severity ordering: only upgrade, never downgrade
_SEVERITY = {"healthy": 0, "warning": 1, "degraded": 2, "unhealthy": 3} _SEVERITY = {"healthy": 0, "warning": 1, "degraded": 2, "unhealthy": 3}
@@ -149,7 +138,7 @@ async def _comprehensive_check() -> dict[str, Any]:
if info: if info:
health_info["unraid_system"] = { health_info["unraid_system"] = {
"status": "connected", "status": "connected",
"url": _safe_display_url(UNRAID_API_URL), "url": safe_display_url(UNRAID_API_URL),
"machine_id": info.get("machineId"), "machine_id": info.get("machineId"),
"version": info.get("versions", {}).get("unraid"), "version": info.get("versions", {}).get("unraid"),
"uptime": info.get("os", {}).get("uptime"), "uptime": info.get("os", {}).get("uptime"),
@@ -220,7 +209,7 @@ async def _comprehensive_check() -> dict[str, Any]:
except Exception as e: except Exception as e:
# Intentionally broad: health checks must always return a result, # Intentionally broad: health checks must always return a result,
# even on unexpected failures, so callers never get an unhandled exception. # even on unexpected failures, so callers never get an unhandled exception.
logger.error(f"Health check failed: {e}") logger.error(f"Health check failed: {e}", exc_info=True)
return { return {
"status": "unhealthy", "status": "unhealthy",
"timestamp": datetime.datetime.now(datetime.UTC).isoformat(), "timestamp": datetime.datetime.now(datetime.UTC).isoformat(),
@@ -293,10 +282,7 @@ async def _diagnose_subscriptions() -> dict[str, Any]:
}, },
} }
except ImportError: except ImportError as e:
return { raise ToolError("Subscription modules not available") from e
"error": "Subscription modules not available",
"timestamp": datetime.datetime.now(datetime.UTC).isoformat(),
}
except Exception as e: except Exception as e:
raise ToolError(f"Failed to generate diagnostics: {e!s}") from e raise ToolError(f"Failed to generate diagnostics: {e!s}") from e

View File

@@ -4,7 +4,7 @@ Provides the `unraid_info` tool with 19 read-only actions for retrieving
system information, array status, network config, and server metadata. system information, array status, network config, and server metadata.
""" """
from typing import Any, Literal from typing import Any, Literal, get_args
from fastmcp import FastMCP from fastmcp import FastMCP
@@ -180,9 +180,9 @@ INFO_ACTIONS = Literal[
"ups_config", "ups_config",
] ]
if set(INFO_ACTIONS.__args__) != ALL_ACTIONS: if set(get_args(INFO_ACTIONS)) != ALL_ACTIONS:
_missing = ALL_ACTIONS - set(INFO_ACTIONS.__args__) _missing = ALL_ACTIONS - set(get_args(INFO_ACTIONS))
_extra = set(INFO_ACTIONS.__args__) - ALL_ACTIONS _extra = set(get_args(INFO_ACTIONS)) - ALL_ACTIONS
raise RuntimeError( raise RuntimeError(
f"QUERIES keys and INFO_ACTIONS are out of sync. " f"QUERIES keys and INFO_ACTIONS are out of sync. "
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}" f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
@@ -415,7 +415,8 @@ def register_info_tool(mcp: FastMCP) -> None:
if action in list_actions: if action in list_actions:
response_key, output_key = list_actions[action] response_key, output_key = list_actions[action]
items = data.get(response_key) or [] items = data.get(response_key) or []
return {output_key: items} normalized_items = list(items) if isinstance(items, list) else []
return {output_key: normalized_items}
raise ToolError(f"Unhandled action '{action}' — this is a bug") raise ToolError(f"Unhandled action '{action}' — this is a bug")

View File

@@ -4,7 +4,7 @@ Provides the `unraid_keys` tool with 5 actions for listing, viewing,
creating, updating, and deleting API keys. creating, updating, and deleting API keys.
""" """
from typing import Any, Literal from typing import Any, Literal, get_args
from fastmcp import FastMCP from fastmcp import FastMCP
@@ -55,6 +55,14 @@ KEY_ACTIONS = Literal[
"delete", "delete",
] ]
if set(get_args(KEY_ACTIONS)) != ALL_ACTIONS:
_missing = ALL_ACTIONS - set(get_args(KEY_ACTIONS))
_extra = set(get_args(KEY_ACTIONS)) - ALL_ACTIONS
raise RuntimeError(
f"KEY_ACTIONS and ALL_ACTIONS are out of sync. "
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
)
def register_keys_tool(mcp: FastMCP) -> None: def register_keys_tool(mcp: FastMCP) -> None:
"""Register the unraid_keys tool with the FastMCP instance.""" """Register the unraid_keys tool with the FastMCP instance."""

View File

@@ -4,7 +4,7 @@ Provides the `unraid_notifications` tool with 9 actions for viewing,
creating, archiving, and deleting system notifications. creating, archiving, and deleting system notifications.
""" """
from typing import Any, Literal from typing import Any, Literal, get_args
from fastmcp import FastMCP from fastmcp import FastMCP
@@ -91,6 +91,14 @@ NOTIFICATION_ACTIONS = Literal[
"archive_all", "archive_all",
] ]
if set(get_args(NOTIFICATION_ACTIONS)) != ALL_ACTIONS:
_missing = ALL_ACTIONS - set(get_args(NOTIFICATION_ACTIONS))
_extra = set(get_args(NOTIFICATION_ACTIONS)) - ALL_ACTIONS
raise RuntimeError(
f"NOTIFICATION_ACTIONS and ALL_ACTIONS are out of sync. "
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
)
def register_notifications_tool(mcp: FastMCP) -> None: def register_notifications_tool(mcp: FastMCP) -> None:
"""Register the unraid_notifications tool with the FastMCP instance.""" """Register the unraid_notifications tool with the FastMCP instance."""

View File

@@ -5,7 +5,7 @@ cloud storage remotes (S3, Google Drive, Dropbox, FTP, etc.).
""" """
import re import re
from typing import Any, Literal from typing import Any, Literal, get_args
from fastmcp import FastMCP from fastmcp import FastMCP
@@ -50,10 +50,18 @@ RCLONE_ACTIONS = Literal[
"delete_remote", "delete_remote",
] ]
if set(get_args(RCLONE_ACTIONS)) != ALL_ACTIONS:
_missing = ALL_ACTIONS - set(get_args(RCLONE_ACTIONS))
_extra = set(get_args(RCLONE_ACTIONS)) - ALL_ACTIONS
raise RuntimeError(
f"RCLONE_ACTIONS and ALL_ACTIONS are out of sync. "
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
)
# Max config entries to prevent abuse # Max config entries to prevent abuse
_MAX_CONFIG_KEYS = 50 _MAX_CONFIG_KEYS = 50
# Pattern for suspicious key names (path traversal, shell metacharacters) # Pattern for suspicious key names (path traversal, shell metacharacters)
_DANGEROUS_KEY_PATTERN = re.compile(r"[.]{2}|[/\\;|`$(){}]") _DANGEROUS_KEY_PATTERN = re.compile(r"\.\.|[/\\;|`$(){}]")
# Max length for individual config values # Max length for individual config values
_MAX_VALUE_LENGTH = 4096 _MAX_VALUE_LENGTH = 4096

View File

@@ -5,7 +5,7 @@ unassigned devices, log files, and log content retrieval.
""" """
import os import os
from typing import Any, Literal from typing import Any, Literal, get_args
from fastmcp import FastMCP from fastmcp import FastMCP
@@ -69,6 +69,14 @@ STORAGE_ACTIONS = Literal[
"logs", "logs",
] ]
if set(get_args(STORAGE_ACTIONS)) != ALL_ACTIONS:
_missing = ALL_ACTIONS - set(get_args(STORAGE_ACTIONS))
_extra = set(get_args(STORAGE_ACTIONS)) - ALL_ACTIONS
raise RuntimeError(
f"STORAGE_ACTIONS and ALL_ACTIONS are out of sync. "
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
)
def register_storage_tool(mcp: FastMCP) -> None: def register_storage_tool(mcp: FastMCP) -> None:
"""Register the unraid_storage tool with the FastMCP instance.""" """Register the unraid_storage tool with the FastMCP instance."""
@@ -96,7 +104,7 @@ def register_storage_tool(mcp: FastMCP) -> None:
if action == "disk_details" and not disk_id: if action == "disk_details" and not disk_id:
raise ToolError("disk_id is required for 'disk_details' action") raise ToolError("disk_id is required for 'disk_details' action")
if tail_lines < 1 or tail_lines > _MAX_TAIL_LINES: if action == "logs" and (tail_lines < 1 or tail_lines > _MAX_TAIL_LINES):
raise ToolError(f"tail_lines must be between 1 and {_MAX_TAIL_LINES}, got {tail_lines}") raise ToolError(f"tail_lines must be between 1 and {_MAX_TAIL_LINES}, got {tail_lines}")
if action == "logs": if action == "logs":

View File

@@ -4,7 +4,7 @@ Provides the `unraid_vm` tool with 9 actions for VM lifecycle management
including start, stop, pause, resume, force stop, reboot, and reset. including start, stop, pause, resume, force stop, reboot, and reset.
""" """
from typing import Any, Literal from typing import Any, Literal, get_args
from fastmcp import FastMCP from fastmcp import FastMCP
@@ -73,6 +73,14 @@ VM_ACTIONS = Literal[
ALL_ACTIONS = set(QUERIES) | set(MUTATIONS) ALL_ACTIONS = set(QUERIES) | set(MUTATIONS)
if set(get_args(VM_ACTIONS)) != ALL_ACTIONS:
_missing = ALL_ACTIONS - set(get_args(VM_ACTIONS))
_extra = set(get_args(VM_ACTIONS)) - ALL_ACTIONS
raise RuntimeError(
f"VM_ACTIONS and ALL_ACTIONS are out of sync. "
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
)
def register_vm_tool(mcp: FastMCP) -> None: def register_vm_tool(mcp: FastMCP) -> None:
"""Register the unraid_vm tool with the FastMCP instance.""" """Register the unraid_vm tool with the FastMCP instance."""

11
unraid_mcp/version.py Normal file
View File

@@ -0,0 +1,11 @@
"""Application version helpers."""
from importlib.metadata import PackageNotFoundError, version
__all__ = ["VERSION"]
try:
VERSION = version("unraid-mcp")
except PackageNotFoundError:
VERSION = "0.0.0"