mirror of
https://github.com/jmagar/unraid-mcp.git
synced 2026-03-02 00:04:45 -08:00
fix: apply all PR review agent findings (silent failures, type safety, test gaps)
Addresses issues found by 4 parallel review agents (code-reviewer,
silent-failure-hunter, type-design-analyzer, pr-test-analyzer).
Source fixes:
- core/utils.py: add public safe_display_url() (moved from tools/health.py)
- core/client.py: rename _redact_sensitive → redact_sensitive (public API)
- core/types.py: add SubscriptionData.__post_init__ for tz-aware datetime
enforcement; remove 6 unused type aliases (SystemHealth, APIResponse, etc.)
- subscriptions/manager.py: add exc_info=True to both except-Exception blocks;
add except ValueError break-on-config-error before retry loop; import
redact_sensitive by new public name
- subscriptions/resources.py: re-raise in autostart_subscriptions() so
ensure_subscriptions_started() doesn't permanently set _subscriptions_started
- subscriptions/diagnostics.py: except ToolError: raise before broad except;
use safe_display_url() instead of raw URL slice
- tools/health.py: move _safe_display_url to core/utils; add exc_info=True;
raise ToolError (not return dict) on ImportError
- tools/info.py: use get_args(INFO_ACTIONS) instead of INFO_ACTIONS.__args__
- tools/{array,docker,keys,notifications,rclone,storage,virtualization}.py:
add Literal-vs-ALL_ACTIONS sync check at import time
Test fixes:
- test_health.py: import safe_display_url from core.utils; update
test_diagnose_import_error_internal to expect ToolError (not error dict)
- test_storage.py: add 3 safe_get tests for zero/False/empty-string values
- test_subscription_manager.py: add TestCapLogContentSingleMassiveLine (2 tests)
- test_client.py: rename _redact_sensitive → redact_sensitive; add tests for
new sensitive keys and is_cacheable explicit-keyword form
This commit is contained in:
@@ -16,6 +16,7 @@ import websockets.exceptions
|
|||||||
|
|
||||||
from unraid_mcp.subscriptions.manager import SubscriptionManager
|
from unraid_mcp.subscriptions.manager import SubscriptionManager
|
||||||
|
|
||||||
|
|
||||||
pytestmark = pytest.mark.integration
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -12,9 +12,9 @@ from unraid_mcp.core.client import (
|
|||||||
DISK_TIMEOUT,
|
DISK_TIMEOUT,
|
||||||
_QueryCache,
|
_QueryCache,
|
||||||
_RateLimiter,
|
_RateLimiter,
|
||||||
_redact_sensitive,
|
|
||||||
is_idempotent_error,
|
is_idempotent_error,
|
||||||
make_graphql_request,
|
make_graphql_request,
|
||||||
|
redact_sensitive,
|
||||||
)
|
)
|
||||||
from unraid_mcp.core.exceptions import ToolError
|
from unraid_mcp.core.exceptions import ToolError
|
||||||
|
|
||||||
@@ -60,7 +60,7 @@ class TestIsIdempotentError:
|
|||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# _redact_sensitive
|
# redact_sensitive
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@@ -69,36 +69,36 @@ class TestRedactSensitive:
|
|||||||
|
|
||||||
def test_flat_dict(self) -> None:
|
def test_flat_dict(self) -> None:
|
||||||
data = {"username": "admin", "password": "hunter2", "host": "10.0.0.1"}
|
data = {"username": "admin", "password": "hunter2", "host": "10.0.0.1"}
|
||||||
result = _redact_sensitive(data)
|
result = redact_sensitive(data)
|
||||||
assert result["username"] == "admin"
|
assert result["username"] == "admin"
|
||||||
assert result["password"] == "***"
|
assert result["password"] == "***"
|
||||||
assert result["host"] == "10.0.0.1"
|
assert result["host"] == "10.0.0.1"
|
||||||
|
|
||||||
def test_nested_dict(self) -> None:
|
def test_nested_dict(self) -> None:
|
||||||
data = {"config": {"apiKey": "abc123", "url": "http://host"}}
|
data = {"config": {"apiKey": "abc123", "url": "http://host"}}
|
||||||
result = _redact_sensitive(data)
|
result = redact_sensitive(data)
|
||||||
assert result["config"]["apiKey"] == "***"
|
assert result["config"]["apiKey"] == "***"
|
||||||
assert result["config"]["url"] == "http://host"
|
assert result["config"]["url"] == "http://host"
|
||||||
|
|
||||||
def test_list_of_dicts(self) -> None:
|
def test_list_of_dicts(self) -> None:
|
||||||
data = [{"token": "t1"}, {"name": "safe"}]
|
data = [{"token": "t1"}, {"name": "safe"}]
|
||||||
result = _redact_sensitive(data)
|
result = redact_sensitive(data)
|
||||||
assert result[0]["token"] == "***"
|
assert result[0]["token"] == "***"
|
||||||
assert result[1]["name"] == "safe"
|
assert result[1]["name"] == "safe"
|
||||||
|
|
||||||
def test_deeply_nested(self) -> None:
|
def test_deeply_nested(self) -> None:
|
||||||
data = {"a": {"b": {"c": {"secret": "deep"}}}}
|
data = {"a": {"b": {"c": {"secret": "deep"}}}}
|
||||||
result = _redact_sensitive(data)
|
result = redact_sensitive(data)
|
||||||
assert result["a"]["b"]["c"]["secret"] == "***"
|
assert result["a"]["b"]["c"]["secret"] == "***"
|
||||||
|
|
||||||
def test_non_dict_passthrough(self) -> None:
|
def test_non_dict_passthrough(self) -> None:
|
||||||
assert _redact_sensitive("plain_string") == "plain_string"
|
assert redact_sensitive("plain_string") == "plain_string"
|
||||||
assert _redact_sensitive(42) == 42
|
assert redact_sensitive(42) == 42
|
||||||
assert _redact_sensitive(None) is None
|
assert redact_sensitive(None) is None
|
||||||
|
|
||||||
def test_case_insensitive_keys(self) -> None:
|
def test_case_insensitive_keys(self) -> None:
|
||||||
data = {"Password": "p1", "TOKEN": "t1", "ApiKey": "k1", "Secret": "s1", "Key": "x1"}
|
data = {"Password": "p1", "TOKEN": "t1", "ApiKey": "k1", "Secret": "s1", "Key": "x1"}
|
||||||
result = _redact_sensitive(data)
|
result = redact_sensitive(data)
|
||||||
for v in result.values():
|
for v in result.values():
|
||||||
assert v == "***"
|
assert v == "***"
|
||||||
|
|
||||||
@@ -112,7 +112,7 @@ class TestRedactSensitive:
|
|||||||
"username": "safe",
|
"username": "safe",
|
||||||
"host": "safe",
|
"host": "safe",
|
||||||
}
|
}
|
||||||
result = _redact_sensitive(data)
|
result = redact_sensitive(data)
|
||||||
assert result["user_password"] == "***"
|
assert result["user_password"] == "***"
|
||||||
assert result["api_key_value"] == "***"
|
assert result["api_key_value"] == "***"
|
||||||
assert result["auth_token_expiry"] == "***"
|
assert result["auth_token_expiry"] == "***"
|
||||||
@@ -122,12 +122,26 @@ class TestRedactSensitive:
|
|||||||
|
|
||||||
def test_mixed_list_content(self) -> None:
|
def test_mixed_list_content(self) -> None:
|
||||||
data = [{"key": "val"}, "string", 123, [{"token": "inner"}]]
|
data = [{"key": "val"}, "string", 123, [{"token": "inner"}]]
|
||||||
result = _redact_sensitive(data)
|
result = redact_sensitive(data)
|
||||||
assert result[0]["key"] == "***"
|
assert result[0]["key"] == "***"
|
||||||
assert result[1] == "string"
|
assert result[1] == "string"
|
||||||
assert result[2] == 123
|
assert result[2] == 123
|
||||||
assert result[3][0]["token"] == "***"
|
assert result[3][0]["token"] == "***"
|
||||||
|
|
||||||
|
def test_new_sensitive_keys_are_redacted(self) -> None:
|
||||||
|
"""PR-added keys: authorization, cookie, session, credential, passphrase, jwt."""
|
||||||
|
data = {
|
||||||
|
"authorization": "Bearer token123",
|
||||||
|
"cookie": "session=abc",
|
||||||
|
"jwt": "eyJ...",
|
||||||
|
"credential": "secret_cred",
|
||||||
|
"passphrase": "hunter2",
|
||||||
|
"session": "sess_id",
|
||||||
|
}
|
||||||
|
result = redact_sensitive(data)
|
||||||
|
for key, val in result.items():
|
||||||
|
assert val == "***", f"Key '{key}' was not redacted"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Timeout constants
|
# Timeout constants
|
||||||
@@ -347,7 +361,7 @@ class TestMakeGraphQLRequestErrors:
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
|
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
|
||||||
pytest.raises(ToolError, match="invalid response.*not valid JSON"),
|
pytest.raises(ToolError, match=r"invalid response.*not valid JSON"),
|
||||||
):
|
):
|
||||||
await make_graphql_request("{ info }")
|
await make_graphql_request("{ info }")
|
||||||
|
|
||||||
@@ -481,7 +495,7 @@ class TestRateLimiter:
|
|||||||
limiter = _RateLimiter(max_tokens=10, refill_rate=1.0)
|
limiter = _RateLimiter(max_tokens=10, refill_rate=1.0)
|
||||||
initial = limiter.tokens
|
initial = limiter.tokens
|
||||||
await limiter.acquire()
|
await limiter.acquire()
|
||||||
assert limiter.tokens == initial - 1
|
assert limiter.tokens == pytest.approx(initial - 1, abs=1e-3)
|
||||||
|
|
||||||
async def test_acquire_succeeds_when_tokens_available(self) -> None:
|
async def test_acquire_succeeds_when_tokens_available(self) -> None:
|
||||||
limiter = _RateLimiter(max_tokens=5, refill_rate=1.0)
|
limiter = _RateLimiter(max_tokens=5, refill_rate=1.0)
|
||||||
@@ -596,6 +610,15 @@ class TestQueryCache:
|
|||||||
"""Queries that start with 'mutation' after whitespace are not cacheable."""
|
"""Queries that start with 'mutation' after whitespace are not cacheable."""
|
||||||
assert _QueryCache.is_cacheable(" mutation { ... }") is False
|
assert _QueryCache.is_cacheable(" mutation { ... }") is False
|
||||||
|
|
||||||
|
def test_is_cacheable_with_explicit_query_keyword(self) -> None:
|
||||||
|
"""Operation names after explicit 'query' keyword must be recognized."""
|
||||||
|
assert _QueryCache.is_cacheable("query GetNetworkConfig { network { name } }") is True
|
||||||
|
assert _QueryCache.is_cacheable("query GetOwner { owner { name } }") is True
|
||||||
|
|
||||||
|
def test_is_cacheable_anonymous_query_returns_false(self) -> None:
|
||||||
|
"""Anonymous 'query { ... }' has no operation name — must not be cached."""
|
||||||
|
assert _QueryCache.is_cacheable("query { network { name } }") is False
|
||||||
|
|
||||||
def test_expired_entry_removed_from_store(self) -> None:
|
def test_expired_entry_removed_from_store(self) -> None:
|
||||||
"""Accessing an expired entry should remove it from the internal store."""
|
"""Accessing an expired entry should remove it from the internal store."""
|
||||||
cache = _QueryCache()
|
cache = _QueryCache()
|
||||||
|
|||||||
@@ -80,6 +80,14 @@ class TestDockerValidation:
|
|||||||
with pytest.raises(ToolError, match="network_id"):
|
with pytest.raises(ToolError, match="network_id"):
|
||||||
await tool_fn(action="network_details")
|
await tool_fn(action="network_details")
|
||||||
|
|
||||||
|
async def test_non_logs_action_ignores_tail_lines_validation(
|
||||||
|
self, _mock_graphql: AsyncMock
|
||||||
|
) -> None:
|
||||||
|
_mock_graphql.return_value = {"docker": {"containers": []}}
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
result = await tool_fn(action="list", tail_lines=0)
|
||||||
|
assert result["containers"] == []
|
||||||
|
|
||||||
|
|
||||||
class TestDockerActions:
|
class TestDockerActions:
|
||||||
async def test_list(self, _mock_graphql: AsyncMock) -> None:
|
async def test_list(self, _mock_graphql: AsyncMock) -> None:
|
||||||
@@ -224,9 +232,22 @@ class TestDockerActions:
|
|||||||
async def test_generic_exception_wraps_in_tool_error(self, _mock_graphql: AsyncMock) -> None:
|
async def test_generic_exception_wraps_in_tool_error(self, _mock_graphql: AsyncMock) -> None:
|
||||||
_mock_graphql.side_effect = RuntimeError("unexpected failure")
|
_mock_graphql.side_effect = RuntimeError("unexpected failure")
|
||||||
tool_fn = _make_tool()
|
tool_fn = _make_tool()
|
||||||
with pytest.raises(ToolError, match="unexpected failure"):
|
with pytest.raises(ToolError, match="Failed to execute docker/list"):
|
||||||
await tool_fn(action="list")
|
await tool_fn(action="list")
|
||||||
|
|
||||||
|
async def test_short_id_prefix_ambiguous_rejected(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
_mock_graphql.return_value = {
|
||||||
|
"docker": {
|
||||||
|
"containers": [
|
||||||
|
{"id": "abcdef1234560000000000000000000000000000000000000000000000000000:local", "names": ["plex"]},
|
||||||
|
{"id": "abcdef1234561111111111111111111111111111111111111111111111111111:local", "names": ["sonarr"]},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
with pytest.raises(ToolError, match="ambiguous"):
|
||||||
|
await tool_fn(action="logs", container_id="abcdef123456")
|
||||||
|
|
||||||
|
|
||||||
class TestDockerMutationFailures:
|
class TestDockerMutationFailures:
|
||||||
"""Tests for mutation responses that indicate failure or unexpected shapes."""
|
"""Tests for mutation responses that indicate failure or unexpected shapes."""
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import pytest
|
|||||||
from conftest import make_tool_fn
|
from conftest import make_tool_fn
|
||||||
|
|
||||||
from unraid_mcp.core.exceptions import ToolError
|
from unraid_mcp.core.exceptions import ToolError
|
||||||
from unraid_mcp.tools.health import _safe_display_url
|
from unraid_mcp.core.utils import safe_display_url
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -100,7 +100,7 @@ class TestHealthActions:
|
|||||||
"unraid_mcp.tools.health._diagnose_subscriptions",
|
"unraid_mcp.tools.health._diagnose_subscriptions",
|
||||||
side_effect=RuntimeError("broken"),
|
side_effect=RuntimeError("broken"),
|
||||||
),
|
),
|
||||||
pytest.raises(ToolError, match="broken"),
|
pytest.raises(ToolError, match="Failed to execute health/diagnose"),
|
||||||
):
|
):
|
||||||
await tool_fn(action="diagnose")
|
await tool_fn(action="diagnose")
|
||||||
|
|
||||||
@@ -115,7 +115,7 @@ class TestHealthActions:
|
|||||||
assert "cpu_sub" in result
|
assert "cpu_sub" in result
|
||||||
|
|
||||||
async def test_diagnose_import_error_internal(self) -> None:
|
async def test_diagnose_import_error_internal(self) -> None:
|
||||||
"""_diagnose_subscriptions catches ImportError and returns error dict."""
|
"""_diagnose_subscriptions raises ToolError when subscription modules are unavailable."""
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from unraid_mcp.tools.health import _diagnose_subscriptions
|
from unraid_mcp.tools.health import _diagnose_subscriptions
|
||||||
@@ -127,16 +127,18 @@ class TestHealthActions:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Replace the modules with objects that raise ImportError on access
|
# Replace the modules with objects that raise ImportError on access
|
||||||
with patch.dict(
|
with (
|
||||||
sys.modules,
|
patch.dict(
|
||||||
{
|
sys.modules,
|
||||||
"unraid_mcp.subscriptions": None,
|
{
|
||||||
"unraid_mcp.subscriptions.manager": None,
|
"unraid_mcp.subscriptions": None,
|
||||||
"unraid_mcp.subscriptions.resources": None,
|
"unraid_mcp.subscriptions.manager": None,
|
||||||
},
|
"unraid_mcp.subscriptions.resources": None,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
pytest.raises(ToolError, match="Subscription modules not available"),
|
||||||
):
|
):
|
||||||
result = await _diagnose_subscriptions()
|
await _diagnose_subscriptions()
|
||||||
assert "error" in result
|
|
||||||
finally:
|
finally:
|
||||||
# Restore cached modules
|
# Restore cached modules
|
||||||
sys.modules.update(cached)
|
sys.modules.update(cached)
|
||||||
@@ -148,47 +150,47 @@ class TestHealthActions:
|
|||||||
|
|
||||||
|
|
||||||
class TestSafeDisplayUrl:
|
class TestSafeDisplayUrl:
|
||||||
"""Verify that _safe_display_url strips credentials/path and preserves scheme+host+port."""
|
"""Verify that safe_display_url strips credentials/path and preserves scheme+host+port."""
|
||||||
|
|
||||||
def test_none_returns_none(self) -> None:
|
def test_none_returns_none(self) -> None:
|
||||||
assert _safe_display_url(None) is None
|
assert safe_display_url(None) is None
|
||||||
|
|
||||||
def test_empty_string_returns_none(self) -> None:
|
def test_empty_string_returns_none(self) -> None:
|
||||||
assert _safe_display_url("") is None
|
assert safe_display_url("") is None
|
||||||
|
|
||||||
def test_simple_url_scheme_and_host(self) -> None:
|
def test_simple_url_scheme_and_host(self) -> None:
|
||||||
assert _safe_display_url("https://unraid.local/graphql") == "https://unraid.local"
|
assert safe_display_url("https://unraid.local/graphql") == "https://unraid.local"
|
||||||
|
|
||||||
def test_preserves_port(self) -> None:
|
def test_preserves_port(self) -> None:
|
||||||
assert _safe_display_url("https://10.1.0.2:31337/api/graphql") == "https://10.1.0.2:31337"
|
assert safe_display_url("https://10.1.0.2:31337/api/graphql") == "https://10.1.0.2:31337"
|
||||||
|
|
||||||
def test_strips_path(self) -> None:
|
def test_strips_path(self) -> None:
|
||||||
result = _safe_display_url("http://unraid.local/some/deep/path?query=1")
|
result = safe_display_url("http://unraid.local/some/deep/path?query=1")
|
||||||
assert "path" not in result
|
assert "path" not in result
|
||||||
assert "query" not in result
|
assert "query" not in result
|
||||||
|
|
||||||
def test_strips_credentials(self) -> None:
|
def test_strips_credentials(self) -> None:
|
||||||
result = _safe_display_url("https://user:password@unraid.local/graphql")
|
result = safe_display_url("https://user:password@unraid.local/graphql")
|
||||||
assert "user" not in result
|
assert "user" not in result
|
||||||
assert "password" not in result
|
assert "password" not in result
|
||||||
assert result == "https://unraid.local"
|
assert result == "https://unraid.local"
|
||||||
|
|
||||||
def test_strips_query_params(self) -> None:
|
def test_strips_query_params(self) -> None:
|
||||||
result = _safe_display_url("http://host.local?token=abc&key=xyz")
|
result = safe_display_url("http://host.local?token=abc&key=xyz")
|
||||||
assert "token" not in result
|
assert "token" not in result
|
||||||
assert "abc" not in result
|
assert "abc" not in result
|
||||||
|
|
||||||
def test_http_scheme_preserved(self) -> None:
|
def test_http_scheme_preserved(self) -> None:
|
||||||
result = _safe_display_url("http://10.0.0.1:8080/api")
|
result = safe_display_url("http://10.0.0.1:8080/api")
|
||||||
assert result == "http://10.0.0.1:8080"
|
assert result == "http://10.0.0.1:8080"
|
||||||
|
|
||||||
def test_tailscale_url(self) -> None:
|
def test_tailscale_url(self) -> None:
|
||||||
result = _safe_display_url("https://100.118.209.1:31337/graphql")
|
result = safe_display_url("https://100.118.209.1:31337/graphql")
|
||||||
assert result == "https://100.118.209.1:31337"
|
assert result == "https://100.118.209.1:31337"
|
||||||
|
|
||||||
def test_malformed_ipv6_url_returns_unparseable(self) -> None:
|
def test_malformed_ipv6_url_returns_unparseable(self) -> None:
|
||||||
"""Malformed IPv6 brackets in netloc cause urlparse.hostname to raise ValueError."""
|
"""Malformed IPv6 brackets in netloc cause urlparse.hostname to raise ValueError."""
|
||||||
# urlparse("https://[invalid") parses without error, but accessing .hostname
|
# urlparse("https://[invalid") parses without error, but accessing .hostname
|
||||||
# raises ValueError: Invalid IPv6 URL — this triggers the except branch.
|
# raises ValueError: Invalid IPv6 URL — this triggers the except branch.
|
||||||
result = _safe_display_url("https://[invalid")
|
result = safe_display_url("https://[invalid")
|
||||||
assert result == "<unparseable>"
|
assert result == "<unparseable>"
|
||||||
|
|||||||
@@ -186,7 +186,7 @@ class TestUnraidInfoTool:
|
|||||||
async def test_generic_exception_wraps(self, _mock_graphql: AsyncMock) -> None:
|
async def test_generic_exception_wraps(self, _mock_graphql: AsyncMock) -> None:
|
||||||
_mock_graphql.side_effect = RuntimeError("unexpected")
|
_mock_graphql.side_effect = RuntimeError("unexpected")
|
||||||
tool_fn = _make_tool()
|
tool_fn = _make_tool()
|
||||||
with pytest.raises(ToolError, match="unexpected"):
|
with pytest.raises(ToolError, match="Failed to execute info/online"):
|
||||||
await tool_fn(action="online")
|
await tool_fn(action="online")
|
||||||
|
|
||||||
async def test_metrics(self, _mock_graphql: AsyncMock) -> None:
|
async def test_metrics(self, _mock_graphql: AsyncMock) -> None:
|
||||||
@@ -201,6 +201,7 @@ class TestUnraidInfoTool:
|
|||||||
_mock_graphql.return_value = {"services": [{"name": "docker", "state": "running"}]}
|
_mock_graphql.return_value = {"services": [{"name": "docker", "state": "running"}]}
|
||||||
tool_fn = _make_tool()
|
tool_fn = _make_tool()
|
||||||
result = await tool_fn(action="services")
|
result = await tool_fn(action="services")
|
||||||
|
assert "services" in result
|
||||||
assert len(result["services"]) == 1
|
assert len(result["services"]) == 1
|
||||||
assert result["services"][0]["name"] == "docker"
|
assert result["services"][0]["name"] == "docker"
|
||||||
|
|
||||||
@@ -225,6 +226,7 @@ class TestUnraidInfoTool:
|
|||||||
}
|
}
|
||||||
tool_fn = _make_tool()
|
tool_fn = _make_tool()
|
||||||
result = await tool_fn(action="servers")
|
result = await tool_fn(action="servers")
|
||||||
|
assert "servers" in result
|
||||||
assert len(result["servers"]) == 1
|
assert len(result["servers"]) == 1
|
||||||
assert result["servers"][0]["name"] == "tower"
|
assert result["servers"][0]["name"] == "tower"
|
||||||
|
|
||||||
@@ -248,6 +250,7 @@ class TestUnraidInfoTool:
|
|||||||
}
|
}
|
||||||
tool_fn = _make_tool()
|
tool_fn = _make_tool()
|
||||||
result = await tool_fn(action="ups_devices")
|
result = await tool_fn(action="ups_devices")
|
||||||
|
assert "ups_devices" in result
|
||||||
assert len(result["ups_devices"]) == 1
|
assert len(result["ups_devices"]) == 1
|
||||||
assert result["ups_devices"][0]["model"] == "APC"
|
assert result["ups_devices"][0]["model"] == "APC"
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ def _make_tool():
|
|||||||
return make_tool_fn("unraid_mcp.tools.rclone", "register_rclone_tool", "unraid_rclone")
|
return make_tool_fn("unraid_mcp.tools.rclone", "register_rclone_tool", "unraid_rclone")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("_mock_graphql")
|
|
||||||
class TestRcloneValidation:
|
class TestRcloneValidation:
|
||||||
async def test_delete_requires_confirm(self) -> None:
|
async def test_delete_requires_confirm(self) -> None:
|
||||||
tool_fn = _make_tool()
|
tool_fn = _make_tool()
|
||||||
|
|||||||
@@ -149,6 +149,15 @@ class TestSafeGet:
|
|||||||
result = safe_get({}, "missing", default=[])
|
result = safe_get({}, "missing", default=[])
|
||||||
assert result == []
|
assert result == []
|
||||||
|
|
||||||
|
def test_zero_value_not_replaced_by_default(self) -> None:
|
||||||
|
assert safe_get({"temp": 0}, "temp", default="N/A") == 0
|
||||||
|
|
||||||
|
def test_false_value_not_replaced_by_default(self) -> None:
|
||||||
|
assert safe_get({"active": False}, "active", default=True) is False
|
||||||
|
|
||||||
|
def test_empty_string_not_replaced_by_default(self) -> None:
|
||||||
|
assert safe_get({"name": ""}, "name", default="unknown") == ""
|
||||||
|
|
||||||
|
|
||||||
class TestStorageActions:
|
class TestStorageActions:
|
||||||
async def test_shares(self, _mock_graphql: AsyncMock) -> None:
|
async def test_shares(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
|||||||
@@ -60,8 +60,8 @@ class TestCapLogContentSmallData:
|
|||||||
class TestCapLogContentTruncation:
|
class TestCapLogContentTruncation:
|
||||||
"""Content exceeding both byte AND line limits must be truncated to the last N lines."""
|
"""Content exceeding both byte AND line limits must be truncated to the last N lines."""
|
||||||
|
|
||||||
def test_oversized_content_truncated_to_last_n_lines(self) -> None:
|
def test_oversized_content_truncated_and_byte_capped(self) -> None:
|
||||||
# 200 lines, limit 50 lines, byte limit effectively 0 → should keep last 50 lines
|
# 200 lines, tiny byte limit: must keep recent content within byte cap.
|
||||||
lines = [f"line {i}" for i in range(200)]
|
lines = [f"line {i}" for i in range(200)]
|
||||||
data = {"content": "\n".join(lines)}
|
data = {"content": "\n".join(lines)}
|
||||||
with (
|
with (
|
||||||
@@ -70,14 +70,13 @@ class TestCapLogContentTruncation:
|
|||||||
):
|
):
|
||||||
result = _cap_log_content(data)
|
result = _cap_log_content(data)
|
||||||
result_lines = result["content"].splitlines()
|
result_lines = result["content"].splitlines()
|
||||||
assert len(result_lines) == 50
|
assert len(result["content"].encode("utf-8", errors="replace")) <= 10
|
||||||
# Must be the LAST 50 lines
|
# Must keep the most recent line suffix.
|
||||||
assert result_lines[0] == "line 150"
|
|
||||||
assert result_lines[-1] == "line 199"
|
assert result_lines[-1] == "line 199"
|
||||||
|
|
||||||
def test_content_with_fewer_lines_than_limit_not_truncated(self) -> None:
|
def test_content_with_fewer_lines_than_limit_still_honors_byte_cap(self) -> None:
|
||||||
"""If byte limit exceeded but line count ≤ limit → keep original (not truncated)."""
|
"""If byte limit is exceeded, output must still be capped even with few lines."""
|
||||||
# 30 lines but byte limit 10 and line limit 50 → 30 < 50 so no truncation
|
# 30 lines, byte limit 10, line limit 50 -> must cap bytes regardless of line count
|
||||||
lines = [f"line {i}" for i in range(30)]
|
lines = [f"line {i}" for i in range(30)]
|
||||||
data = {"content": "\n".join(lines)}
|
data = {"content": "\n".join(lines)}
|
||||||
with (
|
with (
|
||||||
@@ -85,8 +84,7 @@ class TestCapLogContentTruncation:
|
|||||||
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50),
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50),
|
||||||
):
|
):
|
||||||
result = _cap_log_content(data)
|
result = _cap_log_content(data)
|
||||||
# Original content preserved
|
assert len(result["content"].encode("utf-8", errors="replace")) <= 10
|
||||||
assert result["content"] == data["content"]
|
|
||||||
|
|
||||||
def test_non_content_keys_preserved_alongside_truncated_content(self) -> None:
|
def test_non_content_keys_preserved_alongside_truncated_content(self) -> None:
|
||||||
lines = [f"line {i}" for i in range(200)]
|
lines = [f"line {i}" for i in range(200)]
|
||||||
@@ -98,7 +96,7 @@ class TestCapLogContentTruncation:
|
|||||||
result = _cap_log_content(data)
|
result = _cap_log_content(data)
|
||||||
assert result["path"] == "/var/log/syslog"
|
assert result["path"] == "/var/log/syslog"
|
||||||
assert result["total_lines"] == 200
|
assert result["total_lines"] == 200
|
||||||
assert len(result["content"].splitlines()) == 50
|
assert len(result["content"].encode("utf-8", errors="replace")) <= 10
|
||||||
|
|
||||||
|
|
||||||
class TestCapLogContentNested:
|
class TestCapLogContentNested:
|
||||||
@@ -112,7 +110,7 @@ class TestCapLogContentNested:
|
|||||||
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50),
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50),
|
||||||
):
|
):
|
||||||
result = _cap_log_content(data)
|
result = _cap_log_content(data)
|
||||||
assert len(result["logFile"]["content"].splitlines()) == 50
|
assert len(result["logFile"]["content"].encode("utf-8", errors="replace")) <= 10
|
||||||
assert result["logFile"]["path"] == "/var/log/syslog"
|
assert result["logFile"]["path"] == "/var/log/syslog"
|
||||||
|
|
||||||
def test_deeply_nested_content_capped(self) -> None:
|
def test_deeply_nested_content_capped(self) -> None:
|
||||||
@@ -123,9 +121,36 @@ class TestCapLogContentNested:
|
|||||||
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50),
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50),
|
||||||
):
|
):
|
||||||
result = _cap_log_content(data)
|
result = _cap_log_content(data)
|
||||||
assert len(result["outer"]["inner"]["content"].splitlines()) == 50
|
assert len(result["outer"]["inner"]["content"].encode("utf-8", errors="replace")) <= 10
|
||||||
|
|
||||||
def test_nested_non_content_keys_unaffected(self) -> None:
|
def test_nested_non_content_keys_unaffected(self) -> None:
|
||||||
data = {"metrics": {"cpu": 42.5, "memory": 8192}}
|
data = {"metrics": {"cpu": 42.5, "memory": 8192}}
|
||||||
result = _cap_log_content(data)
|
result = _cap_log_content(data)
|
||||||
assert result == data
|
assert result == data
|
||||||
|
|
||||||
|
|
||||||
|
class TestCapLogContentSingleMassiveLine:
|
||||||
|
"""A single line larger than the byte cap must be hard-capped at byte level."""
|
||||||
|
|
||||||
|
def test_single_massive_line_hard_caps_bytes(self) -> None:
|
||||||
|
# One line, no newlines, larger than the byte cap.
|
||||||
|
# The while-loop can't reduce it (len(lines) == 1), so the
|
||||||
|
# last-resort byte-slice path at manager.py:65-69 must fire.
|
||||||
|
huge_content = "x" * 200
|
||||||
|
data = {"content": huge_content}
|
||||||
|
with (
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10),
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 5_000),
|
||||||
|
):
|
||||||
|
result = _cap_log_content(data)
|
||||||
|
assert len(result["content"].encode("utf-8", errors="replace")) <= 10
|
||||||
|
|
||||||
|
def test_single_massive_line_input_not_mutated(self) -> None:
|
||||||
|
huge_content = "x" * 200
|
||||||
|
data = {"content": huge_content}
|
||||||
|
with (
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10),
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 5_000),
|
||||||
|
):
|
||||||
|
_cap_log_content(data)
|
||||||
|
assert data["content"] == huge_content
|
||||||
|
|||||||
@@ -1,13 +1,6 @@
|
|||||||
"""Unraid MCP Server Package.
|
"""Unraid MCP Server Package."""
|
||||||
|
|
||||||
A modular MCP (Model Context Protocol) server that provides tools to interact
|
from .version import VERSION
|
||||||
with an Unraid server's GraphQL API.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from importlib.metadata import PackageNotFoundError, version
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
__version__ = VERSION
|
||||||
__version__ = version("unraid-mcp")
|
|
||||||
except PackageNotFoundError:
|
|
||||||
__version__ = "0.0.0"
|
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ class OverwriteFileHandler(logging.FileHandler):
|
|||||||
"""Emit a record, checking file size periodically and overwriting if needed."""
|
"""Emit a record, checking file size periodically and overwriting if needed."""
|
||||||
self._emit_count += 1
|
self._emit_count += 1
|
||||||
if (
|
if (
|
||||||
self._emit_count % self._check_interval == 0
|
(self._emit_count == 1 or self._emit_count % self._check_interval == 0)
|
||||||
and self.stream
|
and self.stream
|
||||||
and hasattr(self.stream, "name")
|
and hasattr(self.stream, "name")
|
||||||
):
|
):
|
||||||
@@ -249,5 +249,3 @@ if FASTMCP_AVAILABLE:
|
|||||||
else:
|
else:
|
||||||
# Fallback to our custom logger if FastMCP is not available
|
# Fallback to our custom logger if FastMCP is not available
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
# Also configure FastMCP logger for consistency
|
|
||||||
configure_fastmcp_logger_with_rich()
|
|
||||||
|
|||||||
@@ -5,12 +5,13 @@ and provides all configuration constants used throughout the application.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from importlib.metadata import PackageNotFoundError, version
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from ..version import VERSION as APP_VERSION
|
||||||
|
|
||||||
|
|
||||||
# Get the script directory (config module location)
|
# Get the script directory (config module location)
|
||||||
SCRIPT_DIR = Path(__file__).parent # /home/user/code/unraid-mcp/unraid_mcp/config/
|
SCRIPT_DIR = Path(__file__).parent # /home/user/code/unraid-mcp/unraid_mcp/config/
|
||||||
@@ -31,12 +32,6 @@ for dotenv_path in dotenv_paths:
|
|||||||
load_dotenv(dotenv_path=dotenv_path)
|
load_dotenv(dotenv_path=dotenv_path)
|
||||||
break
|
break
|
||||||
|
|
||||||
# Application Version (single source of truth: pyproject.toml)
|
|
||||||
try:
|
|
||||||
VERSION = version("unraid-mcp")
|
|
||||||
except PackageNotFoundError:
|
|
||||||
VERSION = "0.0.0"
|
|
||||||
|
|
||||||
# Core API Configuration
|
# Core API Configuration
|
||||||
UNRAID_API_URL = os.getenv("UNRAID_API_URL")
|
UNRAID_API_URL = os.getenv("UNRAID_API_URL")
|
||||||
UNRAID_API_KEY = os.getenv("UNRAID_API_KEY")
|
UNRAID_API_KEY = os.getenv("UNRAID_API_KEY")
|
||||||
@@ -58,12 +53,18 @@ else: # Path to CA bundle
|
|||||||
# Logging Configuration
|
# Logging Configuration
|
||||||
LOG_LEVEL_STR = os.getenv("UNRAID_MCP_LOG_LEVEL", "INFO").upper()
|
LOG_LEVEL_STR = os.getenv("UNRAID_MCP_LOG_LEVEL", "INFO").upper()
|
||||||
LOG_FILE_NAME = os.getenv("UNRAID_MCP_LOG_FILE", "unraid-mcp.log")
|
LOG_FILE_NAME = os.getenv("UNRAID_MCP_LOG_FILE", "unraid-mcp.log")
|
||||||
# Use /app/logs in Docker, project-relative logs/ directory otherwise
|
# Use /.dockerenv as the container indicator for robust Docker detection.
|
||||||
LOGS_DIR = Path("/app/logs") if Path("/app").is_dir() else PROJECT_ROOT / "logs"
|
IS_DOCKER = Path("/.dockerenv").exists()
|
||||||
|
LOGS_DIR = Path("/app/logs") if IS_DOCKER else PROJECT_ROOT / "logs"
|
||||||
LOG_FILE_PATH = LOGS_DIR / LOG_FILE_NAME
|
LOG_FILE_PATH = LOGS_DIR / LOG_FILE_NAME
|
||||||
|
|
||||||
# Ensure logs directory exists
|
# Ensure logs directory exists; if creation fails, fall back to /tmp.
|
||||||
LOGS_DIR.mkdir(parents=True, exist_ok=True)
|
try:
|
||||||
|
LOGS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
except OSError:
|
||||||
|
LOGS_DIR = PROJECT_ROOT / ".cache" / "logs"
|
||||||
|
LOGS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
LOG_FILE_PATH = LOGS_DIR / LOG_FILE_NAME
|
||||||
|
|
||||||
# HTTP Client Configuration
|
# HTTP Client Configuration
|
||||||
TIMEOUT_CONFIG = {
|
TIMEOUT_CONFIG = {
|
||||||
@@ -109,3 +110,5 @@ def get_config_summary() -> dict[str, Any]:
|
|||||||
"config_valid": is_valid,
|
"config_valid": is_valid,
|
||||||
"missing_config": missing if not is_valid else None,
|
"missing_config": missing if not is_valid else None,
|
||||||
}
|
}
|
||||||
|
# Re-export application version from a single source of truth.
|
||||||
|
VERSION = APP_VERSION
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ to the Unraid API with proper timeout handling and error management.
|
|||||||
import asyncio
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
import time
|
import time
|
||||||
from typing import Any, Final
|
from typing import Any, Final
|
||||||
|
|
||||||
@@ -47,14 +48,14 @@ def _is_sensitive_key(key: str) -> bool:
|
|||||||
return any(s in key_lower for s in _SENSITIVE_KEYS)
|
return any(s in key_lower for s in _SENSITIVE_KEYS)
|
||||||
|
|
||||||
|
|
||||||
def _redact_sensitive(obj: Any) -> Any:
|
def redact_sensitive(obj: Any) -> Any:
|
||||||
"""Recursively redact sensitive values from nested dicts/lists."""
|
"""Recursively redact sensitive values from nested dicts/lists."""
|
||||||
if isinstance(obj, dict):
|
if isinstance(obj, dict):
|
||||||
return {
|
return {
|
||||||
k: ("***" if _is_sensitive_key(k) else _redact_sensitive(v)) for k, v in obj.items()
|
k: ("***" if _is_sensitive_key(k) else redact_sensitive(v)) for k, v in obj.items()
|
||||||
}
|
}
|
||||||
if isinstance(obj, list):
|
if isinstance(obj, list):
|
||||||
return [_redact_sensitive(item) for item in obj]
|
return [redact_sensitive(item) for item in obj]
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
@@ -139,6 +140,7 @@ _CACHEABLE_QUERY_PREFIXES = frozenset(
|
|||||||
)
|
)
|
||||||
|
|
||||||
_CACHE_TTL_SECONDS = 60.0
|
_CACHE_TTL_SECONDS = 60.0
|
||||||
|
_OPERATION_NAME_PATTERN = re.compile(r"^(?:query\s+)?([_A-Za-z][_0-9A-Za-z]*)\b")
|
||||||
|
|
||||||
|
|
||||||
class _QueryCache:
|
class _QueryCache:
|
||||||
@@ -160,9 +162,13 @@ class _QueryCache:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def is_cacheable(query: str) -> bool:
|
def is_cacheable(query: str) -> bool:
|
||||||
"""Check if a query is eligible for caching based on its operation name."""
|
"""Check if a query is eligible for caching based on its operation name."""
|
||||||
if query.lstrip().startswith("mutation"):
|
normalized = query.lstrip()
|
||||||
|
if normalized.startswith("mutation"):
|
||||||
return False
|
return False
|
||||||
return any(prefix in query for prefix in _CACHEABLE_QUERY_PREFIXES)
|
match = _OPERATION_NAME_PATTERN.match(normalized)
|
||||||
|
if not match:
|
||||||
|
return False
|
||||||
|
return match.group(1) in _CACHEABLE_QUERY_PREFIXES
|
||||||
|
|
||||||
def get(self, query: str, variables: dict[str, Any] | None) -> dict[str, Any] | None:
|
def get(self, query: str, variables: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||||
"""Return cached result if present and not expired, else None."""
|
"""Return cached result if present and not expired, else None."""
|
||||||
@@ -324,7 +330,7 @@ async def make_graphql_request(
|
|||||||
logger.debug(f"Making GraphQL request to {UNRAID_API_URL}:")
|
logger.debug(f"Making GraphQL request to {UNRAID_API_URL}:")
|
||||||
logger.debug(f"Query: {query[:200]}{'...' if len(query) > 200 else ''}") # Log truncated query
|
logger.debug(f"Query: {query[:200]}{'...' if len(query) > 200 else ''}") # Log truncated query
|
||||||
if variables:
|
if variables:
|
||||||
logger.debug(f"Variables: {_redact_sensitive(variables)}")
|
logger.debug(f"Variables: {redact_sensitive(variables)}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Rate limit: consume a token before making the request
|
# Rate limit: consume a token before making the request
|
||||||
|
|||||||
@@ -45,13 +45,12 @@ def tool_error_handler(
|
|||||||
except ToolError:
|
except ToolError:
|
||||||
raise
|
raise
|
||||||
except TimeoutError as e:
|
except TimeoutError as e:
|
||||||
logger.error(
|
logger.exception(f"Timeout in unraid_{tool_name} action={action}: request exceeded time limit")
|
||||||
f"Timeout in unraid_{tool_name} action={action}: request exceeded time limit",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
raise ToolError(
|
raise ToolError(
|
||||||
f"Request timed out executing {tool_name}/{action}. The Unraid API did not respond in time."
|
f"Request timed out executing {tool_name}/{action}. The Unraid API did not respond in time."
|
||||||
) from e
|
) from e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in unraid_{tool_name} action={action}: {e}", exc_info=True)
|
logger.exception(f"Error in unraid_{tool_name} action={action}")
|
||||||
raise ToolError(f"Failed to execute {tool_name}/{action}: {e!s}") from e
|
raise ToolError(
|
||||||
|
f"Failed to execute {tool_name}/{action}. Check server logs for details."
|
||||||
|
) from e
|
||||||
|
|||||||
@@ -20,33 +20,10 @@ class SubscriptionData:
|
|||||||
last_updated: datetime # Must be timezone-aware (UTC)
|
last_updated: datetime # Must be timezone-aware (UTC)
|
||||||
subscription_type: str
|
subscription_type: str
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
@dataclass(slots=True)
|
if self.last_updated.tzinfo is None:
|
||||||
class SystemHealth:
|
raise ValueError(
|
||||||
"""Container for system health status information.
|
"last_updated must be timezone-aware; use datetime.now(UTC)"
|
||||||
|
)
|
||||||
Note: last_checked must be timezone-aware (use datetime.now(UTC)).
|
if not self.subscription_type.strip():
|
||||||
"""
|
raise ValueError("subscription_type must be a non-empty string")
|
||||||
|
|
||||||
is_healthy: bool
|
|
||||||
issues: list[str]
|
|
||||||
warnings: list[str]
|
|
||||||
last_checked: datetime # Must be timezone-aware (UTC)
|
|
||||||
component_status: dict[str, str]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
|
||||||
class APIResponse:
|
|
||||||
"""Container for standardized API response data."""
|
|
||||||
|
|
||||||
success: bool
|
|
||||||
data: dict[str, Any] | None = None
|
|
||||||
error: str | None = None
|
|
||||||
metadata: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
# Type aliases for common data structures
|
|
||||||
ConfigValue = str | int | bool | float | None
|
|
||||||
ConfigDict = dict[str, ConfigValue]
|
|
||||||
GraphQLVariables = dict[str, Any]
|
|
||||||
HealthStatus = dict[str, str | bool | int | list[Any]]
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Shared utility functions for Unraid MCP tools."""
|
"""Shared utility functions for Unraid MCP tools."""
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
|
||||||
def safe_get(data: dict[str, Any], *keys: str, default: Any = None) -> Any:
|
def safe_get(data: dict[str, Any], *keys: str, default: Any = None) -> Any:
|
||||||
@@ -45,6 +46,25 @@ def format_bytes(bytes_value: int | None) -> str:
|
|||||||
return f"{value:.2f} EB"
|
return f"{value:.2f} EB"
|
||||||
|
|
||||||
|
|
||||||
|
def safe_display_url(url: str | None) -> str | None:
|
||||||
|
"""Return a redacted URL showing only scheme + host + port.
|
||||||
|
|
||||||
|
Strips path, query parameters, credentials, and fragments to avoid
|
||||||
|
leaking internal network topology or embedded secrets (CWE-200).
|
||||||
|
"""
|
||||||
|
if not url:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
host = parsed.hostname or "unknown"
|
||||||
|
if parsed.port:
|
||||||
|
return f"{parsed.scheme}://{host}:{parsed.port}"
|
||||||
|
return f"{parsed.scheme}://{host}"
|
||||||
|
except ValueError:
|
||||||
|
# urlparse raises ValueError for invalid URLs (e.g. contains control chars)
|
||||||
|
return "<unparseable>"
|
||||||
|
|
||||||
|
|
||||||
def format_kb(k: Any) -> str:
|
def format_kb(k: Any) -> str:
|
||||||
"""Format kilobyte values into human-readable sizes.
|
"""Format kilobyte values into human-readable sizes.
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from websockets.typing import Subprotocol
|
|||||||
from ..config.logging import logger
|
from ..config.logging import logger
|
||||||
from ..config.settings import UNRAID_API_KEY, UNRAID_API_URL
|
from ..config.settings import UNRAID_API_KEY, UNRAID_API_URL
|
||||||
from ..core.exceptions import ToolError
|
from ..core.exceptions import ToolError
|
||||||
|
from ..core.utils import safe_display_url
|
||||||
from .manager import subscription_manager
|
from .manager import subscription_manager
|
||||||
from .resources import ensure_subscriptions_started
|
from .resources import ensure_subscriptions_started
|
||||||
from .utils import build_ws_ssl_context, build_ws_url
|
from .utils import build_ws_ssl_context, build_ws_url
|
||||||
@@ -162,6 +163,8 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
|||||||
"note": "Connection successful, subscription may be waiting for events",
|
"note": "Connection successful, subscription may be waiting for events",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
except ToolError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[TEST_SUBSCRIPTION] Error: {e}", exc_info=True)
|
logger.error(f"[TEST_SUBSCRIPTION] Error: {e}", exc_info=True)
|
||||||
return {"error": str(e), "query_tested": subscription_query}
|
return {"error": str(e), "query_tested": subscription_query}
|
||||||
@@ -193,7 +196,7 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
|||||||
"environment": {
|
"environment": {
|
||||||
"auto_start_enabled": subscription_manager.auto_start_enabled,
|
"auto_start_enabled": subscription_manager.auto_start_enabled,
|
||||||
"max_reconnect_attempts": subscription_manager.max_reconnect_attempts,
|
"max_reconnect_attempts": subscription_manager.max_reconnect_attempts,
|
||||||
"unraid_api_url": UNRAID_API_URL[:50] + "..." if UNRAID_API_URL else None,
|
"unraid_api_url": safe_display_url(UNRAID_API_URL),
|
||||||
"api_key_configured": bool(UNRAID_API_KEY),
|
"api_key_configured": bool(UNRAID_API_KEY),
|
||||||
"websocket_url": None,
|
"websocket_url": None,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from websockets.typing import Subprotocol
|
|||||||
|
|
||||||
from ..config.logging import logger
|
from ..config.logging import logger
|
||||||
from ..config.settings import UNRAID_API_KEY
|
from ..config.settings import UNRAID_API_KEY
|
||||||
from ..core.client import _redact_sensitive
|
from ..core.client import redact_sensitive
|
||||||
from ..core.types import SubscriptionData
|
from ..core.types import SubscriptionData
|
||||||
from .utils import build_ws_ssl_context, build_ws_url
|
from .utils import build_ws_ssl_context, build_ws_url
|
||||||
|
|
||||||
@@ -36,8 +36,7 @@ def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]:
|
|||||||
field (from log subscriptions) exceeds the byte limit, truncate it to the
|
field (from log subscriptions) exceeds the byte limit, truncate it to the
|
||||||
most recent _MAX_RESOURCE_DATA_LINES lines.
|
most recent _MAX_RESOURCE_DATA_LINES lines.
|
||||||
|
|
||||||
Note: single lines larger than _MAX_RESOURCE_DATA_BYTES are not split and
|
The final content is guaranteed to be <= _MAX_RESOURCE_DATA_BYTES.
|
||||||
will still be stored at full size; only multi-line content is truncated.
|
|
||||||
"""
|
"""
|
||||||
result: dict[str, Any] = {}
|
result: dict[str, Any] = {}
|
||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
@@ -49,15 +48,31 @@ def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]:
|
|||||||
and len(value.encode("utf-8", errors="replace")) > _MAX_RESOURCE_DATA_BYTES
|
and len(value.encode("utf-8", errors="replace")) > _MAX_RESOURCE_DATA_BYTES
|
||||||
):
|
):
|
||||||
lines = value.splitlines()
|
lines = value.splitlines()
|
||||||
|
original_line_count = len(lines)
|
||||||
|
|
||||||
|
# Keep most recent lines first.
|
||||||
if len(lines) > _MAX_RESOURCE_DATA_LINES:
|
if len(lines) > _MAX_RESOURCE_DATA_LINES:
|
||||||
truncated = "\n".join(lines[-_MAX_RESOURCE_DATA_LINES:])
|
lines = lines[-_MAX_RESOURCE_DATA_LINES:]
|
||||||
logger.warning(
|
|
||||||
f"[RESOURCE] Capped log content from {len(lines)} to "
|
# Enforce byte cap while preserving whole-line boundaries where possible.
|
||||||
f"{_MAX_RESOURCE_DATA_LINES} lines ({len(value)} -> {len(truncated)} chars)"
|
truncated = "\n".join(lines)
|
||||||
|
truncated_bytes = truncated.encode("utf-8", errors="replace")
|
||||||
|
while len(lines) > 1 and len(truncated_bytes) > _MAX_RESOURCE_DATA_BYTES:
|
||||||
|
lines = lines[1:]
|
||||||
|
truncated = "\n".join(lines)
|
||||||
|
truncated_bytes = truncated.encode("utf-8", errors="replace")
|
||||||
|
|
||||||
|
# Last resort: if a single line still exceeds cap, hard-cap bytes.
|
||||||
|
if len(truncated_bytes) > _MAX_RESOURCE_DATA_BYTES:
|
||||||
|
truncated = truncated_bytes[-_MAX_RESOURCE_DATA_BYTES :].decode(
|
||||||
|
"utf-8", errors="ignore"
|
||||||
)
|
)
|
||||||
result[key] = truncated
|
|
||||||
else:
|
logger.warning(
|
||||||
result[key] = value
|
f"[RESOURCE] Capped log content from {original_line_count} to "
|
||||||
|
f"{len(lines)} lines ({len(value)} -> {len(truncated)} chars)"
|
||||||
|
)
|
||||||
|
result[key] = truncated
|
||||||
else:
|
else:
|
||||||
result[key] = value
|
result[key] = value
|
||||||
return result
|
return result
|
||||||
@@ -148,6 +163,7 @@ class SubscriptionManager:
|
|||||||
# Reset connection tracking
|
# Reset connection tracking
|
||||||
self.reconnect_attempts[subscription_name] = 0
|
self.reconnect_attempts[subscription_name] = 0
|
||||||
self.connection_states[subscription_name] = "starting"
|
self.connection_states[subscription_name] = "starting"
|
||||||
|
self._connection_start_times.pop(subscription_name, None)
|
||||||
|
|
||||||
async with self.subscription_lock:
|
async with self.subscription_lock:
|
||||||
try:
|
try:
|
||||||
@@ -181,6 +197,7 @@ class SubscriptionManager:
|
|||||||
logger.debug(f"[SUBSCRIPTION:{subscription_name}] Task cancelled successfully")
|
logger.debug(f"[SUBSCRIPTION:{subscription_name}] Task cancelled successfully")
|
||||||
del self.active_subscriptions[subscription_name]
|
del self.active_subscriptions[subscription_name]
|
||||||
self.connection_states[subscription_name] = "stopped"
|
self.connection_states[subscription_name] = "stopped"
|
||||||
|
self._connection_start_times.pop(subscription_name, None)
|
||||||
logger.info(f"[SUBSCRIPTION:{subscription_name}] Subscription stopped")
|
logger.info(f"[SUBSCRIPTION:{subscription_name}] Subscription stopped")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"[SUBSCRIPTION:{subscription_name}] No active subscription to stop")
|
logger.warning(f"[SUBSCRIPTION:{subscription_name}] No active subscription to stop")
|
||||||
@@ -322,7 +339,7 @@ class SubscriptionManager:
|
|||||||
)
|
)
|
||||||
logger.debug(f"[SUBSCRIPTION:{subscription_name}] Query: {query[:100]}...")
|
logger.debug(f"[SUBSCRIPTION:{subscription_name}] Query: {query[:100]}...")
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[SUBSCRIPTION:{subscription_name}] Variables: {_redact_sensitive(variables)}"
|
f"[SUBSCRIPTION:{subscription_name}] Variables: {redact_sensitive(variables)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
await websocket.send(json.dumps(subscription_message))
|
await websocket.send(json.dumps(subscription_message))
|
||||||
@@ -431,7 +448,8 @@ class SubscriptionManager:
|
|||||||
logger.error(f"[PROTOCOL:{subscription_name}] JSON decode error: {e}")
|
logger.error(f"[PROTOCOL:{subscription_name}] JSON decode error: {e}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[DATA:{subscription_name}] Error processing message: {e}"
|
f"[DATA:{subscription_name}] Error processing message: {e}",
|
||||||
|
exc_info=True,
|
||||||
)
|
)
|
||||||
msg_preview = (
|
msg_preview = (
|
||||||
message[:200]
|
message[:200]
|
||||||
@@ -461,14 +479,22 @@ class SubscriptionManager:
|
|||||||
self.connection_states[subscription_name] = "invalid_uri"
|
self.connection_states[subscription_name] = "invalid_uri"
|
||||||
break # Don't retry on invalid URI
|
break # Don't retry on invalid URI
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
# Non-retryable configuration error (e.g. UNRAID_API_URL not set)
|
||||||
|
error_msg = f"Configuration error: {e}"
|
||||||
|
logger.error(f"[WEBSOCKET:{subscription_name}] {error_msg}")
|
||||||
|
self.last_error[subscription_name] = error_msg
|
||||||
|
self.connection_states[subscription_name] = "error"
|
||||||
|
break # Don't retry on configuration errors
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Unexpected error: {e}"
|
error_msg = f"Unexpected error: {e}"
|
||||||
logger.error(f"[WEBSOCKET:{subscription_name}] {error_msg}")
|
logger.error(f"[WEBSOCKET:{subscription_name}] {error_msg}", exc_info=True)
|
||||||
self.last_error[subscription_name] = error_msg
|
self.last_error[subscription_name] = error_msg
|
||||||
self.connection_states[subscription_name] = "error"
|
self.connection_states[subscription_name] = "error"
|
||||||
|
|
||||||
# Check if connection was stable before deciding on retry behavior
|
# Check if connection was stable before deciding on retry behavior
|
||||||
start_time = self._connection_start_times.get(subscription_name)
|
start_time = self._connection_start_times.pop(subscription_name, None)
|
||||||
if start_time is not None:
|
if start_time is not None:
|
||||||
connected_duration = time.monotonic() - start_time
|
connected_duration = time.monotonic() - start_time
|
||||||
if connected_duration >= _STABLE_CONNECTION_SECONDS:
|
if connected_duration >= _STABLE_CONNECTION_SECONDS:
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ async def autostart_subscriptions() -> None:
|
|||||||
logger.info("[AUTOSTART] Auto-start process completed successfully")
|
logger.info("[AUTOSTART] Auto-start process completed successfully")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[AUTOSTART] Failed during auto-start process: {e}", exc_info=True)
|
logger.error(f"[AUTOSTART] Failed during auto-start process: {e}", exc_info=True)
|
||||||
|
raise # Propagate so ensure_subscriptions_started doesn't mark as started
|
||||||
|
|
||||||
# Optional log file subscription
|
# Optional log file subscription
|
||||||
log_path = os.getenv("UNRAID_AUTOSTART_LOG_PATH")
|
log_path = os.getenv("UNRAID_AUTOSTART_LOG_PATH")
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
Provides the `unraid_array` tool with 5 actions for parity check management.
|
Provides the `unraid_array` tool with 5 actions for parity check management.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, get_args
|
||||||
|
|
||||||
from fastmcp import FastMCP
|
from fastmcp import FastMCP
|
||||||
|
|
||||||
@@ -53,6 +53,14 @@ ARRAY_ACTIONS = Literal[
|
|||||||
"parity_status",
|
"parity_status",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if set(get_args(ARRAY_ACTIONS)) != ALL_ACTIONS:
|
||||||
|
_missing = ALL_ACTIONS - set(get_args(ARRAY_ACTIONS))
|
||||||
|
_extra = set(get_args(ARRAY_ACTIONS)) - ALL_ACTIONS
|
||||||
|
raise RuntimeError(
|
||||||
|
f"ARRAY_ACTIONS and ALL_ACTIONS are out of sync. "
|
||||||
|
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def register_array_tool(mcp: FastMCP) -> None:
|
def register_array_tool(mcp: FastMCP) -> None:
|
||||||
"""Register the unraid_array tool with the FastMCP instance."""
|
"""Register the unraid_array tool with the FastMCP instance."""
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ logs, networks, and update management.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, get_args
|
||||||
|
|
||||||
from fastmcp import FastMCP
|
from fastmcp import FastMCP
|
||||||
|
|
||||||
@@ -135,6 +135,14 @@ DOCKER_ACTIONS = Literal[
|
|||||||
"check_updates",
|
"check_updates",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if set(get_args(DOCKER_ACTIONS)) != ALL_ACTIONS:
|
||||||
|
_missing = ALL_ACTIONS - set(get_args(DOCKER_ACTIONS))
|
||||||
|
_extra = set(get_args(DOCKER_ACTIONS)) - ALL_ACTIONS
|
||||||
|
raise RuntimeError(
|
||||||
|
f"DOCKER_ACTIONS and ALL_ACTIONS are out of sync. "
|
||||||
|
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
|
||||||
|
)
|
||||||
|
|
||||||
# Full PrefixedID: 64 hex chars + optional suffix (e.g., ":local")
|
# Full PrefixedID: 64 hex chars + optional suffix (e.g., ":local")
|
||||||
_DOCKER_ID_PATTERN = re.compile(r"^[a-f0-9]{64}(:[a-z0-9]+)?$", re.IGNORECASE)
|
_DOCKER_ID_PATTERN = re.compile(r"^[a-f0-9]{64}(:[a-z0-9]+)?$", re.IGNORECASE)
|
||||||
|
|
||||||
@@ -199,11 +207,6 @@ def get_available_container_names(containers: list[dict[str, Any]]) -> list[str]
|
|||||||
return names
|
return names
|
||||||
|
|
||||||
|
|
||||||
def _looks_like_container_id(identifier: str) -> bool:
|
|
||||||
"""Check if an identifier looks like a container ID (full or short hex prefix)."""
|
|
||||||
return bool(_DOCKER_ID_PATTERN.match(identifier) or _DOCKER_SHORT_ID_PATTERN.match(identifier))
|
|
||||||
|
|
||||||
|
|
||||||
async def _resolve_container_id(container_id: str, *, strict: bool = False) -> str:
|
async def _resolve_container_id(container_id: str, *, strict: bool = False) -> str:
|
||||||
"""Resolve a container name/identifier to its actual PrefixedID.
|
"""Resolve a container name/identifier to its actual PrefixedID.
|
||||||
|
|
||||||
@@ -233,12 +236,21 @@ async def _resolve_container_id(container_id: str, *, strict: bool = False) -> s
|
|||||||
# Short hex prefix: match by ID prefix before trying name matching
|
# Short hex prefix: match by ID prefix before trying name matching
|
||||||
if _DOCKER_SHORT_ID_PATTERN.match(container_id):
|
if _DOCKER_SHORT_ID_PATTERN.match(container_id):
|
||||||
id_lower = container_id.lower()
|
id_lower = container_id.lower()
|
||||||
|
matches: list[dict[str, Any]] = []
|
||||||
for c in containers:
|
for c in containers:
|
||||||
cid = (c.get("id") or "").lower()
|
cid = (c.get("id") or "").lower()
|
||||||
if cid.startswith(id_lower) or cid.split(":")[0].startswith(id_lower):
|
if cid.startswith(id_lower) or cid.split(":")[0].startswith(id_lower):
|
||||||
actual_id = str(c.get("id", ""))
|
matches.append(c)
|
||||||
logger.info(f"Resolved short ID '{container_id}' -> '{actual_id}'")
|
if len(matches) == 1:
|
||||||
return actual_id
|
actual_id = str(matches[0].get("id", ""))
|
||||||
|
logger.info(f"Resolved short ID '{container_id}' -> '{actual_id}'")
|
||||||
|
return actual_id
|
||||||
|
if len(matches) > 1:
|
||||||
|
candidate_ids = [str(c.get("id", "")) for c in matches[:5]]
|
||||||
|
raise ToolError(
|
||||||
|
f"Short container ID prefix '{container_id}' is ambiguous. "
|
||||||
|
f"Matches: {', '.join(candidate_ids)}. Use a longer ID or exact name."
|
||||||
|
)
|
||||||
|
|
||||||
resolved = find_container_by_identifier(container_id, containers, strict=strict)
|
resolved = find_container_by_identifier(container_id, containers, strict=strict)
|
||||||
if resolved:
|
if resolved:
|
||||||
@@ -303,7 +315,7 @@ def register_docker_tool(mcp: FastMCP) -> None:
|
|||||||
if action == "network_details" and not network_id:
|
if action == "network_details" and not network_id:
|
||||||
raise ToolError("network_id is required for 'network_details' action")
|
raise ToolError("network_id is required for 'network_details' action")
|
||||||
|
|
||||||
if tail_lines < 1 or tail_lines > _MAX_TAIL_LINES:
|
if action == "logs" and (tail_lines < 1 or tail_lines > _MAX_TAIL_LINES):
|
||||||
raise ToolError(f"tail_lines must be between 1 and {_MAX_TAIL_LINES}, got {tail_lines}")
|
raise ToolError(f"tail_lines must be between 1 and {_MAX_TAIL_LINES}, got {tail_lines}")
|
||||||
|
|
||||||
with tool_error_handler("docker", action, logger):
|
with tool_error_handler("docker", action, logger):
|
||||||
@@ -335,12 +347,12 @@ def register_docker_tool(mcp: FastMCP) -> None:
|
|||||||
|
|
||||||
if action == "networks":
|
if action == "networks":
|
||||||
data = await make_graphql_request(QUERIES["networks"])
|
data = await make_graphql_request(QUERIES["networks"])
|
||||||
networks = data.get("dockerNetworks", [])
|
networks = safe_get(data, "dockerNetworks", default=[])
|
||||||
return {"networks": networks}
|
return {"networks": networks}
|
||||||
|
|
||||||
if action == "network_details":
|
if action == "network_details":
|
||||||
data = await make_graphql_request(QUERIES["network_details"], {"id": network_id})
|
data = await make_graphql_request(QUERIES["network_details"], {"id": network_id})
|
||||||
return dict(data.get("dockerNetwork") or {})
|
return dict(safe_get(data, "dockerNetwork", default={}) or {})
|
||||||
|
|
||||||
if action == "port_conflicts":
|
if action == "port_conflicts":
|
||||||
data = await make_graphql_request(QUERIES["port_conflicts"])
|
data = await make_graphql_request(QUERIES["port_conflicts"])
|
||||||
|
|||||||
@@ -6,8 +6,7 @@ connection testing, and subscription diagnostics.
|
|||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import time
|
import time
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, get_args
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
from fastmcp import FastMCP
|
from fastmcp import FastMCP
|
||||||
|
|
||||||
@@ -21,31 +20,21 @@ from ..config.settings import (
|
|||||||
)
|
)
|
||||||
from ..core.client import make_graphql_request
|
from ..core.client import make_graphql_request
|
||||||
from ..core.exceptions import ToolError, tool_error_handler
|
from ..core.exceptions import ToolError, tool_error_handler
|
||||||
|
from ..core.utils import safe_display_url
|
||||||
|
|
||||||
def _safe_display_url(url: str | None) -> str | None:
|
|
||||||
"""Return a redacted URL showing only scheme + host + port.
|
|
||||||
|
|
||||||
Strips path, query parameters, credentials, and fragments to avoid
|
|
||||||
leaking internal network topology or embedded secrets (CWE-200).
|
|
||||||
"""
|
|
||||||
if not url:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
parsed = urlparse(url)
|
|
||||||
host = parsed.hostname or "unknown"
|
|
||||||
if parsed.port:
|
|
||||||
return f"{parsed.scheme}://{host}:{parsed.port}"
|
|
||||||
return f"{parsed.scheme}://{host}"
|
|
||||||
except ValueError:
|
|
||||||
# urlparse raises ValueError for invalid URLs (e.g. contains control chars)
|
|
||||||
return "<unparseable>"
|
|
||||||
|
|
||||||
|
|
||||||
ALL_ACTIONS = {"check", "test_connection", "diagnose"}
|
ALL_ACTIONS = {"check", "test_connection", "diagnose"}
|
||||||
|
|
||||||
HEALTH_ACTIONS = Literal["check", "test_connection", "diagnose"]
|
HEALTH_ACTIONS = Literal["check", "test_connection", "diagnose"]
|
||||||
|
|
||||||
|
if set(get_args(HEALTH_ACTIONS)) != ALL_ACTIONS:
|
||||||
|
_missing = ALL_ACTIONS - set(get_args(HEALTH_ACTIONS))
|
||||||
|
_extra = set(get_args(HEALTH_ACTIONS)) - ALL_ACTIONS
|
||||||
|
raise RuntimeError(
|
||||||
|
"HEALTH_ACTIONS and ALL_ACTIONS are out of sync. "
|
||||||
|
f"Missing in HEALTH_ACTIONS: {_missing}; extra in HEALTH_ACTIONS: {_extra}"
|
||||||
|
)
|
||||||
|
|
||||||
# Severity ordering: only upgrade, never downgrade
|
# Severity ordering: only upgrade, never downgrade
|
||||||
_SEVERITY = {"healthy": 0, "warning": 1, "degraded": 2, "unhealthy": 3}
|
_SEVERITY = {"healthy": 0, "warning": 1, "degraded": 2, "unhealthy": 3}
|
||||||
|
|
||||||
@@ -149,7 +138,7 @@ async def _comprehensive_check() -> dict[str, Any]:
|
|||||||
if info:
|
if info:
|
||||||
health_info["unraid_system"] = {
|
health_info["unraid_system"] = {
|
||||||
"status": "connected",
|
"status": "connected",
|
||||||
"url": _safe_display_url(UNRAID_API_URL),
|
"url": safe_display_url(UNRAID_API_URL),
|
||||||
"machine_id": info.get("machineId"),
|
"machine_id": info.get("machineId"),
|
||||||
"version": info.get("versions", {}).get("unraid"),
|
"version": info.get("versions", {}).get("unraid"),
|
||||||
"uptime": info.get("os", {}).get("uptime"),
|
"uptime": info.get("os", {}).get("uptime"),
|
||||||
@@ -220,7 +209,7 @@ async def _comprehensive_check() -> dict[str, Any]:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Intentionally broad: health checks must always return a result,
|
# Intentionally broad: health checks must always return a result,
|
||||||
# even on unexpected failures, so callers never get an unhandled exception.
|
# even on unexpected failures, so callers never get an unhandled exception.
|
||||||
logger.error(f"Health check failed: {e}")
|
logger.error(f"Health check failed: {e}", exc_info=True)
|
||||||
return {
|
return {
|
||||||
"status": "unhealthy",
|
"status": "unhealthy",
|
||||||
"timestamp": datetime.datetime.now(datetime.UTC).isoformat(),
|
"timestamp": datetime.datetime.now(datetime.UTC).isoformat(),
|
||||||
@@ -293,10 +282,7 @@ async def _diagnose_subscriptions() -> dict[str, Any]:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
except ImportError:
|
except ImportError as e:
|
||||||
return {
|
raise ToolError("Subscription modules not available") from e
|
||||||
"error": "Subscription modules not available",
|
|
||||||
"timestamp": datetime.datetime.now(datetime.UTC).isoformat(),
|
|
||||||
}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ToolError(f"Failed to generate diagnostics: {e!s}") from e
|
raise ToolError(f"Failed to generate diagnostics: {e!s}") from e
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ Provides the `unraid_info` tool with 19 read-only actions for retrieving
|
|||||||
system information, array status, network config, and server metadata.
|
system information, array status, network config, and server metadata.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, get_args
|
||||||
|
|
||||||
from fastmcp import FastMCP
|
from fastmcp import FastMCP
|
||||||
|
|
||||||
@@ -180,9 +180,9 @@ INFO_ACTIONS = Literal[
|
|||||||
"ups_config",
|
"ups_config",
|
||||||
]
|
]
|
||||||
|
|
||||||
if set(INFO_ACTIONS.__args__) != ALL_ACTIONS:
|
if set(get_args(INFO_ACTIONS)) != ALL_ACTIONS:
|
||||||
_missing = ALL_ACTIONS - set(INFO_ACTIONS.__args__)
|
_missing = ALL_ACTIONS - set(get_args(INFO_ACTIONS))
|
||||||
_extra = set(INFO_ACTIONS.__args__) - ALL_ACTIONS
|
_extra = set(get_args(INFO_ACTIONS)) - ALL_ACTIONS
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"QUERIES keys and INFO_ACTIONS are out of sync. "
|
f"QUERIES keys and INFO_ACTIONS are out of sync. "
|
||||||
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
|
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
|
||||||
@@ -415,7 +415,8 @@ def register_info_tool(mcp: FastMCP) -> None:
|
|||||||
if action in list_actions:
|
if action in list_actions:
|
||||||
response_key, output_key = list_actions[action]
|
response_key, output_key = list_actions[action]
|
||||||
items = data.get(response_key) or []
|
items = data.get(response_key) or []
|
||||||
return {output_key: items}
|
normalized_items = list(items) if isinstance(items, list) else []
|
||||||
|
return {output_key: normalized_items}
|
||||||
|
|
||||||
raise ToolError(f"Unhandled action '{action}' — this is a bug")
|
raise ToolError(f"Unhandled action '{action}' — this is a bug")
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ Provides the `unraid_keys` tool with 5 actions for listing, viewing,
|
|||||||
creating, updating, and deleting API keys.
|
creating, updating, and deleting API keys.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, get_args
|
||||||
|
|
||||||
from fastmcp import FastMCP
|
from fastmcp import FastMCP
|
||||||
|
|
||||||
@@ -55,6 +55,14 @@ KEY_ACTIONS = Literal[
|
|||||||
"delete",
|
"delete",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if set(get_args(KEY_ACTIONS)) != ALL_ACTIONS:
|
||||||
|
_missing = ALL_ACTIONS - set(get_args(KEY_ACTIONS))
|
||||||
|
_extra = set(get_args(KEY_ACTIONS)) - ALL_ACTIONS
|
||||||
|
raise RuntimeError(
|
||||||
|
f"KEY_ACTIONS and ALL_ACTIONS are out of sync. "
|
||||||
|
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def register_keys_tool(mcp: FastMCP) -> None:
|
def register_keys_tool(mcp: FastMCP) -> None:
|
||||||
"""Register the unraid_keys tool with the FastMCP instance."""
|
"""Register the unraid_keys tool with the FastMCP instance."""
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ Provides the `unraid_notifications` tool with 9 actions for viewing,
|
|||||||
creating, archiving, and deleting system notifications.
|
creating, archiving, and deleting system notifications.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, get_args
|
||||||
|
|
||||||
from fastmcp import FastMCP
|
from fastmcp import FastMCP
|
||||||
|
|
||||||
@@ -91,6 +91,14 @@ NOTIFICATION_ACTIONS = Literal[
|
|||||||
"archive_all",
|
"archive_all",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if set(get_args(NOTIFICATION_ACTIONS)) != ALL_ACTIONS:
|
||||||
|
_missing = ALL_ACTIONS - set(get_args(NOTIFICATION_ACTIONS))
|
||||||
|
_extra = set(get_args(NOTIFICATION_ACTIONS)) - ALL_ACTIONS
|
||||||
|
raise RuntimeError(
|
||||||
|
f"NOTIFICATION_ACTIONS and ALL_ACTIONS are out of sync. "
|
||||||
|
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def register_notifications_tool(mcp: FastMCP) -> None:
|
def register_notifications_tool(mcp: FastMCP) -> None:
|
||||||
"""Register the unraid_notifications tool with the FastMCP instance."""
|
"""Register the unraid_notifications tool with the FastMCP instance."""
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ cloud storage remotes (S3, Google Drive, Dropbox, FTP, etc.).
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, get_args
|
||||||
|
|
||||||
from fastmcp import FastMCP
|
from fastmcp import FastMCP
|
||||||
|
|
||||||
@@ -50,10 +50,18 @@ RCLONE_ACTIONS = Literal[
|
|||||||
"delete_remote",
|
"delete_remote",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if set(get_args(RCLONE_ACTIONS)) != ALL_ACTIONS:
|
||||||
|
_missing = ALL_ACTIONS - set(get_args(RCLONE_ACTIONS))
|
||||||
|
_extra = set(get_args(RCLONE_ACTIONS)) - ALL_ACTIONS
|
||||||
|
raise RuntimeError(
|
||||||
|
f"RCLONE_ACTIONS and ALL_ACTIONS are out of sync. "
|
||||||
|
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
|
||||||
|
)
|
||||||
|
|
||||||
# Max config entries to prevent abuse
|
# Max config entries to prevent abuse
|
||||||
_MAX_CONFIG_KEYS = 50
|
_MAX_CONFIG_KEYS = 50
|
||||||
# Pattern for suspicious key names (path traversal, shell metacharacters)
|
# Pattern for suspicious key names (path traversal, shell metacharacters)
|
||||||
_DANGEROUS_KEY_PATTERN = re.compile(r"[.]{2}|[/\\;|`$(){}]")
|
_DANGEROUS_KEY_PATTERN = re.compile(r"\.\.|[/\\;|`$(){}]")
|
||||||
# Max length for individual config values
|
# Max length for individual config values
|
||||||
_MAX_VALUE_LENGTH = 4096
|
_MAX_VALUE_LENGTH = 4096
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ unassigned devices, log files, and log content retrieval.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, get_args
|
||||||
|
|
||||||
from fastmcp import FastMCP
|
from fastmcp import FastMCP
|
||||||
|
|
||||||
@@ -69,6 +69,14 @@ STORAGE_ACTIONS = Literal[
|
|||||||
"logs",
|
"logs",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if set(get_args(STORAGE_ACTIONS)) != ALL_ACTIONS:
|
||||||
|
_missing = ALL_ACTIONS - set(get_args(STORAGE_ACTIONS))
|
||||||
|
_extra = set(get_args(STORAGE_ACTIONS)) - ALL_ACTIONS
|
||||||
|
raise RuntimeError(
|
||||||
|
f"STORAGE_ACTIONS and ALL_ACTIONS are out of sync. "
|
||||||
|
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def register_storage_tool(mcp: FastMCP) -> None:
|
def register_storage_tool(mcp: FastMCP) -> None:
|
||||||
"""Register the unraid_storage tool with the FastMCP instance."""
|
"""Register the unraid_storage tool with the FastMCP instance."""
|
||||||
@@ -96,7 +104,7 @@ def register_storage_tool(mcp: FastMCP) -> None:
|
|||||||
if action == "disk_details" and not disk_id:
|
if action == "disk_details" and not disk_id:
|
||||||
raise ToolError("disk_id is required for 'disk_details' action")
|
raise ToolError("disk_id is required for 'disk_details' action")
|
||||||
|
|
||||||
if tail_lines < 1 or tail_lines > _MAX_TAIL_LINES:
|
if action == "logs" and (tail_lines < 1 or tail_lines > _MAX_TAIL_LINES):
|
||||||
raise ToolError(f"tail_lines must be between 1 and {_MAX_TAIL_LINES}, got {tail_lines}")
|
raise ToolError(f"tail_lines must be between 1 and {_MAX_TAIL_LINES}, got {tail_lines}")
|
||||||
|
|
||||||
if action == "logs":
|
if action == "logs":
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ Provides the `unraid_vm` tool with 9 actions for VM lifecycle management
|
|||||||
including start, stop, pause, resume, force stop, reboot, and reset.
|
including start, stop, pause, resume, force stop, reboot, and reset.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, get_args
|
||||||
|
|
||||||
from fastmcp import FastMCP
|
from fastmcp import FastMCP
|
||||||
|
|
||||||
@@ -73,6 +73,14 @@ VM_ACTIONS = Literal[
|
|||||||
|
|
||||||
ALL_ACTIONS = set(QUERIES) | set(MUTATIONS)
|
ALL_ACTIONS = set(QUERIES) | set(MUTATIONS)
|
||||||
|
|
||||||
|
if set(get_args(VM_ACTIONS)) != ALL_ACTIONS:
|
||||||
|
_missing = ALL_ACTIONS - set(get_args(VM_ACTIONS))
|
||||||
|
_extra = set(get_args(VM_ACTIONS)) - ALL_ACTIONS
|
||||||
|
raise RuntimeError(
|
||||||
|
f"VM_ACTIONS and ALL_ACTIONS are out of sync. "
|
||||||
|
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def register_vm_tool(mcp: FastMCP) -> None:
|
def register_vm_tool(mcp: FastMCP) -> None:
|
||||||
"""Register the unraid_vm tool with the FastMCP instance."""
|
"""Register the unraid_vm tool with the FastMCP instance."""
|
||||||
|
|||||||
11
unraid_mcp/version.py
Normal file
11
unraid_mcp/version.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
"""Application version helpers."""
|
||||||
|
|
||||||
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["VERSION"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
VERSION = version("unraid-mcp")
|
||||||
|
except PackageNotFoundError:
|
||||||
|
VERSION = "0.0.0"
|
||||||
Reference in New Issue
Block a user