diff --git a/tests/safety/test_destructive_guards.py b/tests/safety/test_destructive_guards.py index 14d66f3..43bf230 100644 --- a/tests/safety/test_destructive_guards.py +++ b/tests/safety/test_destructive_guards.py @@ -39,7 +39,7 @@ KNOWN_DESTRUCTIVE: dict[str, dict[str, set[str]]] = { "module": "unraid_mcp.tools.docker", "register_fn": "register_docker_tool", "tool_name": "unraid_docker", - "actions": {"remove"}, + "actions": {"remove", "update_all"}, "runtime_set": DOCKER_DESTRUCTIVE, }, "vm": { @@ -143,6 +143,7 @@ class TestDestructiveActionRegistries: _DESTRUCTIVE_TEST_CASES: list[tuple[str, str, dict]] = [ # Docker ("docker", "remove", {"container_id": "abc123"}), + ("docker", "update_all", {}), # VM ("vm", "force_stop", {"vm_id": "test-vm-uuid"}), ("vm", "reset", {"vm_id": "test-vm-uuid"}), diff --git a/tests/test_client.py b/tests/test_client.py index 9208d76..c90f797 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,6 +1,7 @@ """Tests for unraid_mcp.core.client — GraphQL client infrastructure.""" import json +import time from unittest.mock import AsyncMock, MagicMock, patch import httpx @@ -9,6 +10,8 @@ import pytest from unraid_mcp.core.client import ( DEFAULT_TIMEOUT, DISK_TIMEOUT, + _QueryCache, + _RateLimiter, _redact_sensitive, is_idempotent_error, make_graphql_request, @@ -464,3 +467,231 @@ class TestGraphQLErrorHandling: pytest.raises(ToolError, match="GraphQL API error"), ): await make_graphql_request("{ info }") + + +# --------------------------------------------------------------------------- +# _RateLimiter +# --------------------------------------------------------------------------- + + +class TestRateLimiter: + """Unit tests for the token bucket rate limiter.""" + + async def test_acquire_consumes_one_token(self) -> None: + limiter = _RateLimiter(max_tokens=10, refill_rate=1.0) + initial = limiter.tokens + await limiter.acquire() + assert limiter.tokens == initial - 1 + + async def test_acquire_succeeds_when_tokens_available(self) -> None: + limiter = _RateLimiter(max_tokens=5, refill_rate=1.0) + # Should complete without sleeping + for _ in range(5): + await limiter.acquire() + # _refill() runs during each acquire() call and adds a tiny time-based + # amount; check < 1.0 (not enough for another immediate request) rather + # than == 0.0 to avoid flakiness from timing. + assert limiter.tokens < 1.0 + + async def test_tokens_do_not_exceed_max(self) -> None: + limiter = _RateLimiter(max_tokens=10, refill_rate=1.0) + # Force refill with large elapsed time + limiter.last_refill = time.monotonic() - 100.0 # 100 seconds ago + limiter._refill() + assert limiter.tokens == 10.0 # Capped at max_tokens + + async def test_refill_adds_tokens_based_on_elapsed(self) -> None: + limiter = _RateLimiter(max_tokens=100, refill_rate=10.0) + limiter.tokens = 0.0 + limiter.last_refill = time.monotonic() - 1.0 # 1 second ago + limiter._refill() + # Should have refilled ~10 tokens (10.0 rate * 1.0 sec) + assert 9.5 < limiter.tokens < 10.5 + + async def test_acquire_sleeps_when_no_tokens(self) -> None: + """When tokens are exhausted, acquire should sleep before consuming.""" + limiter = _RateLimiter(max_tokens=1, refill_rate=1.0) + limiter.tokens = 0.0 + + sleep_calls = [] + + async def fake_sleep(duration: float) -> None: + sleep_calls.append(duration) + # Simulate refill by advancing last_refill so tokens replenish + limiter.tokens = 1.0 + limiter.last_refill = time.monotonic() + + with patch("unraid_mcp.core.client.asyncio.sleep", side_effect=fake_sleep): + await limiter.acquire() + + assert len(sleep_calls) == 1 + assert sleep_calls[0] > 0 + + async def test_default_params_match_api_limits(self) -> None: + """Default rate limiter must use 90 tokens at 9.0/sec (10% headroom from 100/10s).""" + limiter = _RateLimiter() + assert limiter.max_tokens == 90 + assert limiter.refill_rate == 9.0 + + +# --------------------------------------------------------------------------- +# _QueryCache +# --------------------------------------------------------------------------- + + +class TestQueryCache: + """Unit tests for the TTL query cache.""" + + def test_miss_on_empty_cache(self) -> None: + cache = _QueryCache() + assert cache.get("{ info }", None) is None + + def test_put_and_get_hit(self) -> None: + cache = _QueryCache() + data = {"result": "ok"} + cache.put("GetNetworkConfig { }", None, data) + result = cache.get("GetNetworkConfig { }", None) + assert result == data + + def test_expired_entry_returns_none(self) -> None: + cache = _QueryCache() + data = {"result": "ok"} + cache.put("GetNetworkConfig { }", None, data) + # Manually expire the entry + key = cache._cache_key("GetNetworkConfig { }", None) + cache._store[key] = (time.monotonic() - 1.0, data) # expired 1 sec ago + assert cache.get("GetNetworkConfig { }", None) is None + + def test_invalidate_all_clears_store(self) -> None: + cache = _QueryCache() + cache.put("GetNetworkConfig { }", None, {"x": 1}) + cache.put("GetOwner { }", None, {"y": 2}) + assert len(cache._store) == 2 + cache.invalidate_all() + assert len(cache._store) == 0 + + def test_variables_affect_cache_key(self) -> None: + """Different variables produce different cache keys.""" + cache = _QueryCache() + q = "GetNetworkConfig($id: ID!) { network(id: $id) { name } }" + cache.put(q, {"id": "1"}, {"name": "eth0"}) + cache.put(q, {"id": "2"}, {"name": "eth1"}) + assert cache.get(q, {"id": "1"}) == {"name": "eth0"} + assert cache.get(q, {"id": "2"}) == {"name": "eth1"} + + def test_is_cacheable_returns_true_for_known_prefixes(self) -> None: + assert _QueryCache.is_cacheable("GetNetworkConfig { ... }") is True + assert _QueryCache.is_cacheable("GetRegistrationInfo { ... }") is True + assert _QueryCache.is_cacheable("GetOwner { ... }") is True + assert _QueryCache.is_cacheable("GetFlash { ... }") is True + + def test_is_cacheable_returns_false_for_mutations(self) -> None: + assert _QueryCache.is_cacheable('mutation { docker { start(id: "x") } }') is False + + def test_is_cacheable_returns_false_for_unlisted_queries(self) -> None: + assert _QueryCache.is_cacheable("{ docker { containers { id } } }") is False + assert _QueryCache.is_cacheable("{ info { os } }") is False + + def test_is_cacheable_mutation_check_is_prefix(self) -> None: + """Queries that start with 'mutation' after whitespace are not cacheable.""" + assert _QueryCache.is_cacheable(" mutation { ... }") is False + + def test_expired_entry_removed_from_store(self) -> None: + """Accessing an expired entry should remove it from the internal store.""" + cache = _QueryCache() + cache.put("GetOwner { }", None, {"owner": "root"}) + key = cache._cache_key("GetOwner { }", None) + cache._store[key] = (time.monotonic() - 1.0, {"owner": "root"}) + assert key in cache._store + cache.get("GetOwner { }", None) # triggers deletion + assert key not in cache._store + + +# --------------------------------------------------------------------------- +# make_graphql_request — 429 retry behavior +# --------------------------------------------------------------------------- + + +class TestRateLimitRetry: + """Tests for the 429 retry loop in make_graphql_request.""" + + @pytest.fixture(autouse=True) + def _patch_config(self): + with ( + patch("unraid_mcp.core.client.UNRAID_API_URL", "https://unraid.local/graphql"), + patch("unraid_mcp.core.client.UNRAID_API_KEY", "test-key"), + patch("unraid_mcp.core.client.asyncio.sleep", new_callable=AsyncMock), + ): + yield + + def _make_429_response(self) -> MagicMock: + resp = MagicMock() + resp.status_code = 429 + resp.raise_for_status = MagicMock() + return resp + + def _make_ok_response(self, data: dict) -> MagicMock: + resp = MagicMock() + resp.status_code = 200 + resp.raise_for_status = MagicMock() + resp.json.return_value = {"data": data} + return resp + + async def test_single_429_then_success_retries(self) -> None: + """One 429 followed by a success should return the data.""" + mock_client = AsyncMock() + mock_client.post.side_effect = [ + self._make_429_response(), + self._make_ok_response({"info": {"os": "Unraid"}}), + ] + + with patch("unraid_mcp.core.client.get_http_client", return_value=mock_client): + result = await make_graphql_request("{ info { os } }") + + assert result == {"info": {"os": "Unraid"}} + assert mock_client.post.call_count == 2 + + async def test_two_429s_then_success(self) -> None: + """Two 429s followed by success returns data after 2 retries.""" + mock_client = AsyncMock() + mock_client.post.side_effect = [ + self._make_429_response(), + self._make_429_response(), + self._make_ok_response({"x": 1}), + ] + + with patch("unraid_mcp.core.client.get_http_client", return_value=mock_client): + result = await make_graphql_request("{ x }") + + assert result == {"x": 1} + assert mock_client.post.call_count == 3 + + async def test_three_429s_raises_tool_error(self) -> None: + """Three consecutive 429s (all retries exhausted) raises ToolError.""" + mock_client = AsyncMock() + mock_client.post.side_effect = [ + self._make_429_response(), + self._make_429_response(), + self._make_429_response(), + ] + + with ( + patch("unraid_mcp.core.client.get_http_client", return_value=mock_client), + pytest.raises(ToolError, match="rate limiting"), + ): + await make_graphql_request("{ info }") + + async def test_rate_limit_error_message_advises_wait(self) -> None: + """The ToolError message should tell the user to wait ~10 seconds.""" + mock_client = AsyncMock() + mock_client.post.side_effect = [ + self._make_429_response(), + self._make_429_response(), + self._make_429_response(), + ] + + with ( + patch("unraid_mcp.core.client.get_http_client", return_value=mock_client), + pytest.raises(ToolError, match="10 seconds"), + ): + await make_graphql_request("{ info }") diff --git a/tests/test_docker.py b/tests/test_docker.py index c725979..c3591ff 100644 --- a/tests/test_docker.py +++ b/tests/test_docker.py @@ -175,7 +175,7 @@ class TestDockerActions: "docker": {"updateAllContainers": [{"id": "c1", "state": "running"}]} } tool_fn = _make_tool() - result = await tool_fn(action="update_all") + result = await tool_fn(action="update_all", confirm=True) assert result["success"] is True assert len(result["containers"]) == 1 @@ -271,10 +271,16 @@ class TestDockerMutationFailures: """update_all with no containers to update.""" _mock_graphql.return_value = {"docker": {"updateAllContainers": []}} tool_fn = _make_tool() - result = await tool_fn(action="update_all") + result = await tool_fn(action="update_all", confirm=True) assert result["success"] is True assert result["containers"] == [] + async def test_update_all_requires_confirm(self, _mock_graphql: AsyncMock) -> None: + """update_all is destructive and requires confirm=True.""" + tool_fn = _make_tool() + with pytest.raises(ToolError, match="destructive"): + await tool_fn(action="update_all") + async def test_mutation_timeout(self, _mock_graphql: AsyncMock) -> None: """Mid-operation timeout during a docker mutation.""" diff --git a/tests/test_health.py b/tests/test_health.py index b0e978a..de2f835 100644 --- a/tests/test_health.py +++ b/tests/test_health.py @@ -7,6 +7,7 @@ import pytest from conftest import make_tool_fn from unraid_mcp.core.exceptions import ToolError +from unraid_mcp.tools.health import _safe_display_url @pytest.fixture @@ -139,3 +140,55 @@ class TestHealthActions: finally: # Restore cached modules sys.modules.update(cached) + + +# --------------------------------------------------------------------------- +# _safe_display_url — URL redaction helper +# --------------------------------------------------------------------------- + + +class TestSafeDisplayUrl: + """Verify that _safe_display_url strips credentials/path and preserves scheme+host+port.""" + + def test_none_returns_none(self) -> None: + assert _safe_display_url(None) is None + + def test_empty_string_returns_none(self) -> None: + assert _safe_display_url("") is None + + def test_simple_url_scheme_and_host(self) -> None: + assert _safe_display_url("https://unraid.local/graphql") == "https://unraid.local" + + def test_preserves_port(self) -> None: + assert _safe_display_url("https://10.1.0.2:31337/api/graphql") == "https://10.1.0.2:31337" + + def test_strips_path(self) -> None: + result = _safe_display_url("http://unraid.local/some/deep/path?query=1") + assert "path" not in result + assert "query" not in result + + def test_strips_credentials(self) -> None: + result = _safe_display_url("https://user:password@unraid.local/graphql") + assert "user" not in result + assert "password" not in result + assert result == "https://unraid.local" + + def test_strips_query_params(self) -> None: + result = _safe_display_url("http://host.local?token=abc&key=xyz") + assert "token" not in result + assert "abc" not in result + + def test_http_scheme_preserved(self) -> None: + result = _safe_display_url("http://10.0.0.1:8080/api") + assert result == "http://10.0.0.1:8080" + + def test_tailscale_url(self) -> None: + result = _safe_display_url("https://100.118.209.1:31337/graphql") + assert result == "https://100.118.209.1:31337" + + def test_malformed_ipv6_url_returns_unparseable(self) -> None: + """Malformed IPv6 brackets in netloc cause urlparse.hostname to raise ValueError.""" + # urlparse("https://[invalid") parses without error, but accessing .hostname + # raises ValueError: Invalid IPv6 URL — this triggers the except branch. + result = _safe_display_url("https://[invalid") + assert result == "" diff --git a/tests/test_notifications.py b/tests/test_notifications.py index af07977..946fed7 100644 --- a/tests/test_notifications.py +++ b/tests/test_notifications.py @@ -151,3 +151,87 @@ class TestNotificationsActions: tool_fn = _make_tool() with pytest.raises(ToolError, match="boom"): await tool_fn(action="overview") + + +class TestNotificationsCreateValidation: + """Tests for importance enum and field length validation added in this PR.""" + + async def test_invalid_importance_rejected(self, _mock_graphql: AsyncMock) -> None: + tool_fn = _make_tool() + with pytest.raises(ToolError, match="importance must be one of"): + await tool_fn( + action="create", + title="T", + subject="S", + description="D", + importance="invalid", + ) + + async def test_info_importance_rejected(self, _mock_graphql: AsyncMock) -> None: + """INFO is listed in old docstring examples but rejected by the validator.""" + tool_fn = _make_tool() + with pytest.raises(ToolError, match="importance must be one of"): + await tool_fn( + action="create", + title="T", + subject="S", + description="D", + importance="info", + ) + + async def test_alert_importance_accepted(self, _mock_graphql: AsyncMock) -> None: + _mock_graphql.return_value = { + "notifications": {"createNotification": {"id": "n:1", "importance": "ALERT"}} + } + tool_fn = _make_tool() + result = await tool_fn( + action="create", title="T", subject="S", description="D", importance="alert" + ) + assert result["success"] is True + + async def test_title_too_long_rejected(self, _mock_graphql: AsyncMock) -> None: + tool_fn = _make_tool() + with pytest.raises(ToolError, match="title must be at most 200"): + await tool_fn( + action="create", + title="x" * 201, + subject="S", + description="D", + importance="normal", + ) + + async def test_subject_too_long_rejected(self, _mock_graphql: AsyncMock) -> None: + tool_fn = _make_tool() + with pytest.raises(ToolError, match="subject must be at most 500"): + await tool_fn( + action="create", + title="T", + subject="x" * 501, + description="D", + importance="normal", + ) + + async def test_description_too_long_rejected(self, _mock_graphql: AsyncMock) -> None: + tool_fn = _make_tool() + with pytest.raises(ToolError, match="description must be at most 2000"): + await tool_fn( + action="create", + title="T", + subject="S", + description="x" * 2001, + importance="normal", + ) + + async def test_title_at_max_accepted(self, _mock_graphql: AsyncMock) -> None: + _mock_graphql.return_value = { + "notifications": {"createNotification": {"id": "n:1", "importance": "NORMAL"}} + } + tool_fn = _make_tool() + result = await tool_fn( + action="create", + title="x" * 200, + subject="S", + description="D", + importance="normal", + ) + assert result["success"] is True diff --git a/tests/test_rclone.py b/tests/test_rclone.py index 45a0477..caf93cd 100644 --- a/tests/test_rclone.py +++ b/tests/test_rclone.py @@ -100,3 +100,83 @@ class TestRcloneActions: tool_fn = _make_tool() with pytest.raises(ToolError, match="Failed to delete"): await tool_fn(action="delete_remote", name="gdrive", confirm=True) + + +class TestRcloneConfigDataValidation: + """Tests for _validate_config_data security guards.""" + + async def test_path_traversal_in_key_rejected(self, _mock_graphql: AsyncMock) -> None: + tool_fn = _make_tool() + with pytest.raises(ToolError, match="disallowed characters"): + await tool_fn( + action="create_remote", + name="r", + provider_type="s3", + config_data={"../evil": "value"}, + ) + + async def test_shell_metachar_in_key_rejected(self, _mock_graphql: AsyncMock) -> None: + tool_fn = _make_tool() + with pytest.raises(ToolError, match="disallowed characters"): + await tool_fn( + action="create_remote", + name="r", + provider_type="s3", + config_data={"key;rm": "value"}, + ) + + async def test_too_many_keys_rejected(self, _mock_graphql: AsyncMock) -> None: + tool_fn = _make_tool() + with pytest.raises(ToolError, match="max 50"): + await tool_fn( + action="create_remote", + name="r", + provider_type="s3", + config_data={f"key{i}": "v" for i in range(51)}, + ) + + async def test_dict_value_rejected(self, _mock_graphql: AsyncMock) -> None: + tool_fn = _make_tool() + with pytest.raises(ToolError, match="string, number, or boolean"): + await tool_fn( + action="create_remote", + name="r", + provider_type="s3", + config_data={"nested": {"key": "val"}}, + ) + + async def test_value_too_long_rejected(self, _mock_graphql: AsyncMock) -> None: + tool_fn = _make_tool() + with pytest.raises(ToolError, match="exceeds max length"): + await tool_fn( + action="create_remote", + name="r", + provider_type="s3", + config_data={"key": "x" * 4097}, + ) + + async def test_boolean_value_accepted(self, _mock_graphql: AsyncMock) -> None: + _mock_graphql.return_value = { + "rclone": {"createRCloneRemote": {"name": "r", "type": "s3"}} + } + tool_fn = _make_tool() + result = await tool_fn( + action="create_remote", + name="r", + provider_type="s3", + config_data={"use_path_style": True}, + ) + assert result["success"] is True + + async def test_int_value_accepted(self, _mock_graphql: AsyncMock) -> None: + _mock_graphql.return_value = { + "rclone": {"createRCloneRemote": {"name": "r", "type": "sftp"}} + } + tool_fn = _make_tool() + result = await tool_fn( + action="create_remote", + name="r", + provider_type="sftp", + config_data={"port": 22}, + ) + assert result["success"] is True diff --git a/tests/test_storage.py b/tests/test_storage.py index 77d5ea9..eac4c36 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -7,7 +7,7 @@ import pytest from conftest import make_tool_fn from unraid_mcp.core.exceptions import ToolError -from unraid_mcp.core.utils import format_bytes +from unraid_mcp.core.utils import format_bytes, format_kb, safe_get # --- Unit tests for helpers --- @@ -77,6 +77,70 @@ class TestStorageValidation: result = await tool_fn(action="logs", log_path="/var/log/syslog") assert result["content"] == "ok" + async def test_logs_tail_lines_too_large(self, _mock_graphql: AsyncMock) -> None: + tool_fn = _make_tool() + with pytest.raises(ToolError, match="tail_lines must be between"): + await tool_fn(action="logs", log_path="/var/log/syslog", tail_lines=10_001) + + async def test_logs_tail_lines_zero_rejected(self, _mock_graphql: AsyncMock) -> None: + tool_fn = _make_tool() + with pytest.raises(ToolError, match="tail_lines must be between"): + await tool_fn(action="logs", log_path="/var/log/syslog", tail_lines=0) + + async def test_logs_tail_lines_at_max_accepted(self, _mock_graphql: AsyncMock) -> None: + _mock_graphql.return_value = {"logFile": {"path": "/var/log/syslog", "content": "ok"}} + tool_fn = _make_tool() + result = await tool_fn(action="logs", log_path="/var/log/syslog", tail_lines=10_000) + assert result["content"] == "ok" + + +class TestFormatKb: + def test_none_returns_na(self) -> None: + assert format_kb(None) == "N/A" + + def test_invalid_string_returns_na(self) -> None: + assert format_kb("not-a-number") == "N/A" + + def test_kilobytes_range(self) -> None: + assert format_kb(512) == "512 KB" + + def test_megabytes_range(self) -> None: + assert format_kb(2048) == "2.00 MB" + + def test_gigabytes_range(self) -> None: + assert format_kb(1_048_576) == "1.00 GB" + + def test_terabytes_range(self) -> None: + assert format_kb(1_073_741_824) == "1.00 TB" + + def test_boundary_exactly_1024_kb(self) -> None: + # 1024 KB = 1 MB + assert format_kb(1024) == "1.00 MB" + + +class TestSafeGet: + def test_simple_key_access(self) -> None: + assert safe_get({"a": 1}, "a") == 1 + + def test_nested_key_access(self) -> None: + assert safe_get({"a": {"b": "val"}}, "a", "b") == "val" + + def test_missing_key_returns_none(self) -> None: + assert safe_get({"a": 1}, "missing") is None + + def test_none_intermediate_returns_default(self) -> None: + assert safe_get({"a": None}, "a", "b") is None + + def test_custom_default_returned(self) -> None: + assert safe_get({}, "x", default="fallback") == "fallback" + + def test_non_dict_intermediate_returns_default(self) -> None: + assert safe_get({"a": "string"}, "a", "b") is None + + def test_empty_list_default(self) -> None: + result = safe_get({}, "missing", default=[]) + assert result == [] + class TestStorageActions: async def test_shares(self, _mock_graphql: AsyncMock) -> None: diff --git a/tests/test_subscription_manager.py b/tests/test_subscription_manager.py new file mode 100644 index 0000000..53b5080 --- /dev/null +++ b/tests/test_subscription_manager.py @@ -0,0 +1,131 @@ +"""Tests for _cap_log_content in subscriptions/manager.py. + +_cap_log_content is a pure utility that prevents unbounded memory growth from +log subscription data. It must: return a NEW dict (not mutate), recursively +cap nested 'content' fields, and only truncate when both byte limit and line +limit are exceeded. +""" + +from unittest.mock import patch + +from unraid_mcp.subscriptions.manager import _cap_log_content + + +class TestCapLogContentImmutability: + """The function must return a new dict — never mutate the input.""" + + def test_returns_new_dict(self) -> None: + data = {"key": "value"} + result = _cap_log_content(data) + assert result is not data + + def test_input_not_mutated_on_passthrough(self) -> None: + data = {"content": "short text", "other": "value"} + original_content = data["content"] + _cap_log_content(data) + assert data["content"] == original_content + + def test_input_not_mutated_on_truncation(self) -> None: + # Use small limits so the truncation path is exercised + large_content = "\n".join(f"line {i}" for i in range(200)) + data = {"content": large_content} + with ( + patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10), + patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50), + ): + _cap_log_content(data) + # Original data must be unchanged + assert data["content"] == large_content + + +class TestCapLogContentSmallData: + """Content below the byte limit must be returned unchanged.""" + + def test_small_content_unchanged(self) -> None: + data = {"content": "just a few lines\nof log data\n"} + result = _cap_log_content(data) + assert result["content"] == data["content"] + + def test_non_content_keys_passed_through(self) -> None: + data = {"name": "cpu_subscription", "timestamp": "2026-02-18T00:00:00Z"} + result = _cap_log_content(data) + assert result == data + + def test_integer_value_passed_through(self) -> None: + data = {"count": 42, "active": True} + result = _cap_log_content(data) + assert result == data + + +class TestCapLogContentTruncation: + """Content exceeding both byte AND line limits must be truncated to the last N lines.""" + + def test_oversized_content_truncated_to_last_n_lines(self) -> None: + # 200 lines, limit 50 lines, byte limit effectively 0 → should keep last 50 lines + lines = [f"line {i}" for i in range(200)] + data = {"content": "\n".join(lines)} + with ( + patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10), + patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50), + ): + result = _cap_log_content(data) + result_lines = result["content"].splitlines() + assert len(result_lines) == 50 + # Must be the LAST 50 lines + assert result_lines[0] == "line 150" + assert result_lines[-1] == "line 199" + + def test_content_with_fewer_lines_than_limit_not_truncated(self) -> None: + """If byte limit exceeded but line count ≤ limit → keep original (not truncated).""" + # 30 lines but byte limit 10 and line limit 50 → 30 < 50 so no truncation + lines = [f"line {i}" for i in range(30)] + data = {"content": "\n".join(lines)} + with ( + patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10), + patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50), + ): + result = _cap_log_content(data) + # Original content preserved + assert result["content"] == data["content"] + + def test_non_content_keys_preserved_alongside_truncated_content(self) -> None: + lines = [f"line {i}" for i in range(200)] + data = {"content": "\n".join(lines), "path": "/var/log/syslog", "total_lines": 200} + with ( + patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10), + patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50), + ): + result = _cap_log_content(data) + assert result["path"] == "/var/log/syslog" + assert result["total_lines"] == 200 + assert len(result["content"].splitlines()) == 50 + + +class TestCapLogContentNested: + """Nested 'content' fields inside sub-dicts must also be capped recursively.""" + + def test_nested_content_field_capped(self) -> None: + lines = [f"line {i}" for i in range(200)] + data = {"logFile": {"content": "\n".join(lines), "path": "/var/log/syslog"}} + with ( + patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10), + patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50), + ): + result = _cap_log_content(data) + assert len(result["logFile"]["content"].splitlines()) == 50 + assert result["logFile"]["path"] == "/var/log/syslog" + + def test_deeply_nested_content_capped(self) -> None: + lines = [f"line {i}" for i in range(200)] + data = {"outer": {"inner": {"content": "\n".join(lines)}}} + with ( + patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10), + patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50), + ): + result = _cap_log_content(data) + assert len(result["outer"]["inner"]["content"].splitlines()) == 50 + + def test_nested_non_content_keys_unaffected(self) -> None: + data = {"metrics": {"cpu": 42.5, "memory": 8192}} + result = _cap_log_content(data) + assert result == data diff --git a/tests/test_subscription_validation.py b/tests/test_subscription_validation.py new file mode 100644 index 0000000..879919f --- /dev/null +++ b/tests/test_subscription_validation.py @@ -0,0 +1,131 @@ +"""Tests for _validate_subscription_query in diagnostics.py. + +Security-critical: this function is the only guard against arbitrary GraphQL +operations (mutations, queries) being sent over the WebSocket subscription channel. +""" + +import pytest + +from unraid_mcp.core.exceptions import ToolError +from unraid_mcp.subscriptions.diagnostics import ( + _ALLOWED_SUBSCRIPTION_NAMES, + _validate_subscription_query, +) + + +class TestValidateSubscriptionQueryAllowed: + """All whitelisted subscription names must be accepted.""" + + @pytest.mark.parametrize("sub_name", sorted(_ALLOWED_SUBSCRIPTION_NAMES)) + def test_all_allowed_names_accepted(self, sub_name: str) -> None: + query = f"subscription {{ {sub_name} {{ data }} }}" + result = _validate_subscription_query(query) + assert result == sub_name + + def test_returns_extracted_subscription_name(self) -> None: + query = "subscription { cpuSubscription { usage } }" + assert _validate_subscription_query(query) == "cpuSubscription" + + def test_leading_whitespace_accepted(self) -> None: + query = " subscription { memorySubscription { free } }" + assert _validate_subscription_query(query) == "memorySubscription" + + def test_multiline_query_accepted(self) -> None: + query = "subscription {\n logFileSubscription {\n content\n }\n}" + assert _validate_subscription_query(query) == "logFileSubscription" + + def test_case_insensitive_subscription_keyword(self) -> None: + """'SUBSCRIPTION' should be accepted (regex uses IGNORECASE).""" + query = "SUBSCRIPTION { cpuSubscription { usage } }" + assert _validate_subscription_query(query) == "cpuSubscription" + + +class TestValidateSubscriptionQueryForbiddenKeywords: + """Queries containing 'mutation' or 'query' as standalone keywords must be rejected.""" + + def test_mutation_keyword_rejected(self) -> None: + query = 'mutation { docker { start(id: "abc") } }' + with pytest.raises(ToolError, match="must be a subscription"): + _validate_subscription_query(query) + + def test_query_keyword_rejected(self) -> None: + query = "query { info { os { platform } } }" + with pytest.raises(ToolError, match="must be a subscription"): + _validate_subscription_query(query) + + def test_mutation_embedded_in_subscription_rejected(self) -> None: + """'mutation' anywhere in the string triggers rejection.""" + query = "subscription { cpuSubscription { mutation data } }" + with pytest.raises(ToolError, match="must be a subscription"): + _validate_subscription_query(query) + + def test_query_embedded_in_subscription_rejected(self) -> None: + query = "subscription { cpuSubscription { query data } }" + with pytest.raises(ToolError, match="must be a subscription"): + _validate_subscription_query(query) + + def test_mutation_case_insensitive_rejection(self) -> None: + query = 'MUTATION { docker { start(id: "abc") } }' + with pytest.raises(ToolError, match="must be a subscription"): + _validate_subscription_query(query) + + def test_mutation_field_identifier_not_rejected(self) -> None: + """'mutationField' as an identifier must NOT be rejected — only standalone 'mutation'.""" + # This tests the \b word boundary in _FORBIDDEN_KEYWORDS + query = "subscription { cpuSubscription { mutationField } }" + # Should not raise — "mutationField" is an identifier, not the keyword + result = _validate_subscription_query(query) + assert result == "cpuSubscription" + + def test_query_field_identifier_not_rejected(self) -> None: + """'queryResult' as an identifier must NOT be rejected.""" + query = "subscription { cpuSubscription { queryResult } }" + result = _validate_subscription_query(query) + assert result == "cpuSubscription" + + +class TestValidateSubscriptionQueryInvalidFormat: + """Queries that don't match the expected subscription format must be rejected.""" + + def test_empty_string_rejected(self) -> None: + with pytest.raises(ToolError, match="must start with 'subscription'"): + _validate_subscription_query("") + + def test_plain_identifier_rejected(self) -> None: + with pytest.raises(ToolError, match="must start with 'subscription'"): + _validate_subscription_query("cpuSubscription { usage }") + + def test_missing_operation_body_rejected(self) -> None: + with pytest.raises(ToolError, match="must start with 'subscription'"): + _validate_subscription_query("subscription") + + def test_subscription_without_field_rejected(self) -> None: + """subscription { } with no field name doesn't match the pattern.""" + with pytest.raises(ToolError, match="must start with 'subscription'"): + _validate_subscription_query("subscription { }") + + +class TestValidateSubscriptionQueryUnknownName: + """Subscription names not in the whitelist must be rejected even if format is valid.""" + + def test_unknown_subscription_name_rejected(self) -> None: + query = "subscription { unknownSubscription { data } }" + with pytest.raises(ToolError, match="not allowed"): + _validate_subscription_query(query) + + def test_error_message_includes_allowed_list(self) -> None: + """Error message must list the allowed subscription names for usability.""" + query = "subscription { badSub { data } }" + with pytest.raises(ToolError, match="Allowed subscriptions"): + _validate_subscription_query(query) + + def test_arbitrary_field_name_rejected(self) -> None: + query = "subscription { users { id email } }" + with pytest.raises(ToolError, match="not allowed"): + _validate_subscription_query(query) + + def test_close_but_not_whitelisted_rejected(self) -> None: + """'cpu' without 'Subscription' suffix is not in the allow-list.""" + query = "subscription { cpu { usage } }" + with pytest.raises(ToolError, match="not allowed"): + _validate_subscription_query(query) diff --git a/unraid_mcp/core/client.py b/unraid_mcp/core/client.py index 9c6369b..ea568cf 100644 --- a/unraid_mcp/core/client.py +++ b/unraid_mcp/core/client.py @@ -8,7 +8,7 @@ import asyncio import hashlib import json import time -from typing import Any +from typing import Any, Final import httpx @@ -23,20 +23,22 @@ from ..config.settings import ( from ..core.exceptions import ToolError -# Sensitive keys to redact from debug logs -_SENSITIVE_KEYS = { - "password", - "key", - "secret", - "token", - "apikey", - "authorization", - "cookie", - "session", - "credential", - "passphrase", - "jwt", -} +# Sensitive keys to redact from debug logs (frozenset — immutable, Final — no accidental reassignment) +_SENSITIVE_KEYS: Final[frozenset[str]] = frozenset( + { + "password", + "key", + "secret", + "token", + "apikey", + "authorization", + "cookie", + "session", + "credential", + "passphrase", + "jwt", + } +) def _is_sensitive_key(key: str) -> bool: @@ -80,16 +82,9 @@ def get_timeout_for_operation(profile: str) -> httpx.Timeout: # Global connection pool (module-level singleton) +# Python 3.12+ asyncio.Lock() is safe at module level — no running event loop required _http_client: httpx.AsyncClient | None = None -_client_lock: asyncio.Lock | None = None - - -def _get_client_lock() -> asyncio.Lock: - """Get or create the client lock (lazy init to avoid event loop issues).""" - global _client_lock - if _client_lock is None: - _client_lock = asyncio.Lock() - return _client_lock +_client_lock: Final[asyncio.Lock] = asyncio.Lock() class _RateLimiter: @@ -103,12 +98,8 @@ class _RateLimiter: self.tokens = float(max_tokens) self.refill_rate = refill_rate # tokens per second self.last_refill = time.monotonic() - self._lock: asyncio.Lock | None = None - - def _get_lock(self) -> asyncio.Lock: - if self._lock is None: - self._lock = asyncio.Lock() - return self._lock + # asyncio.Lock() is safe to create at __init__ time (Python 3.12+) + self._lock: Final[asyncio.Lock] = asyncio.Lock() def _refill(self) -> None: """Refill tokens based on elapsed time.""" @@ -120,7 +111,7 @@ class _RateLimiter: async def acquire(self) -> None: """Consume one token, waiting if necessary for refill.""" while True: - async with self._get_lock(): + async with self._lock: self._refill() if self.tokens >= 1: self.tokens -= 1 @@ -266,7 +257,7 @@ async def get_http_client() -> httpx.AsyncClient: return client # Slow-path: acquire lock for initialization - async with _get_client_lock(): + async with _client_lock: if _http_client is None or _http_client.is_closed: _http_client = await _create_http_client() logger.info( @@ -279,7 +270,7 @@ async def close_http_client() -> None: """Close the shared HTTP client (call on server shutdown).""" global _http_client - async with _get_client_lock(): + async with _client_lock: if _http_client is not None: await _http_client.aclose() _http_client = None @@ -361,6 +352,14 @@ async def make_graphql_request( if response is None: # pragma: no cover — guaranteed by loop raise ToolError("No response received after retry attempts") + + # Provide a clear message when all retries are exhausted on 429 + if response.status_code == 429: + logger.error("Rate limit (429) persisted after 3 retries — request aborted") + raise ToolError( + "Unraid API is rate limiting requests. Wait ~10 seconds before retrying." + ) + response.raise_for_status() # Raise an exception for HTTP error codes 4xx/5xx response_data = response.json() diff --git a/unraid_mcp/core/exceptions.py b/unraid_mcp/core/exceptions.py index c5b99cf..34d84d1 100644 --- a/unraid_mcp/core/exceptions.py +++ b/unraid_mcp/core/exceptions.py @@ -6,7 +6,7 @@ throughout the application, with proper integration to FastMCP's error system. import contextlib import logging -from collections.abc import Generator +from collections.abc import Iterator from fastmcp.exceptions import ToolError as FastMCPToolError @@ -28,11 +28,12 @@ def tool_error_handler( tool_name: str, action: str, logger: logging.Logger, -) -> Generator[None]: +) -> Iterator[None]: """Context manager that standardizes tool error handling. - Re-raises ToolError as-is. Catches all other exceptions, logs them - with full traceback, and wraps them in ToolError with a descriptive message. + Re-raises ToolError as-is. Gives TimeoutError a descriptive message. + Catches all other exceptions, logs them with full traceback, and wraps them + in ToolError with a descriptive message. Args: tool_name: The tool name for error messages (e.g., "docker", "vm"). @@ -43,6 +44,14 @@ def tool_error_handler( yield except ToolError: raise + except TimeoutError as e: + logger.error( + f"Timeout in unraid_{tool_name} action={action}: request exceeded time limit", + exc_info=True, + ) + raise ToolError( + f"Request timed out executing {tool_name}/{action}. The Unraid API did not respond in time." + ) from e except Exception as e: logger.error(f"Error in unraid_{tool_name} action={action}: {e}", exc_info=True) raise ToolError(f"Failed to execute {tool_name}/{action}: {e!s}") from e diff --git a/unraid_mcp/subscriptions/diagnostics.py b/unraid_mcp/subscriptions/diagnostics.py index 88da6e8..f72d010 100644 --- a/unraid_mcp/subscriptions/diagnostics.py +++ b/unraid_mcp/subscriptions/diagnostics.py @@ -37,8 +37,10 @@ _ALLOWED_SUBSCRIPTION_NAMES = frozenset( } ) -# Pattern: must start with "subscription", contain only a known subscription name, -# and not contain mutation/query keywords or semicolons (prevents injection). +# Pattern: must start with "subscription" and contain only a known subscription name. +# _FORBIDDEN_KEYWORDS rejects any query that contains standalone "mutation" or "query" +# as distinct words. Word boundaries (\b) ensure "mutationField"-style identifiers are +# not rejected — only bare "mutation" or "query" operation keywords are blocked. _SUBSCRIPTION_NAME_PATTERN = re.compile(r"^\s*subscription\b[^{]*\{\s*(\w+)", re.IGNORECASE) _FORBIDDEN_KEYWORDS = re.compile(r"\b(mutation|query)\b", re.IGNORECASE) diff --git a/unraid_mcp/subscriptions/manager.py b/unraid_mcp/subscriptions/manager.py index 75b948d..0416e2e 100644 --- a/unraid_mcp/subscriptions/manager.py +++ b/unraid_mcp/subscriptions/manager.py @@ -32,12 +32,17 @@ _STABLE_CONNECTION_SECONDS = 30 def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]: """Cap log content in subscription data to prevent unbounded memory growth. - If the data contains a 'content' field (from log subscriptions) that exceeds - size limits, truncate to the most recent _MAX_RESOURCE_DATA_LINES lines. + Returns a new dict — does NOT mutate the input. If any nested 'content' + field (from log subscriptions) exceeds the byte limit, truncate it to the + most recent _MAX_RESOURCE_DATA_LINES lines. + + Note: single lines larger than _MAX_RESOURCE_DATA_BYTES are not split and + will still be stored at full size; only multi-line content is truncated. """ + result: dict[str, Any] = {} for key, value in data.items(): if isinstance(value, dict): - data[key] = _cap_log_content(value) + result[key] = _cap_log_content(value) elif ( key == "content" and isinstance(value, str) @@ -50,8 +55,12 @@ def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]: f"[RESOURCE] Capped log content from {len(lines)} to " f"{_MAX_RESOURCE_DATA_LINES} lines ({len(value)} -> {len(truncated)} chars)" ) - data[key] = truncated - return data + result[key] = truncated + else: + result[key] = value + else: + result[key] = value + return result class SubscriptionManager: @@ -355,11 +364,13 @@ class SubscriptionManager: if isinstance(payload["data"], dict) else payload["data"] ) - self.resource_data[subscription_name] = SubscriptionData( + new_entry = SubscriptionData( data=capped_data, last_updated=datetime.now(UTC), subscription_type=subscription_name, ) + async with self.subscription_lock: + self.resource_data[subscription_name] = new_entry logger.debug( f"[RESOURCE:{subscription_name}] Resource data updated successfully" ) @@ -484,6 +495,16 @@ class SubscriptionManager: self.connection_states[subscription_name] = "reconnecting" await asyncio.sleep(retry_delay) + # The while loop exited (via break or max_retries exceeded). + # Remove from active_subscriptions so start_subscription() can restart it. + async with self.subscription_lock: + self.active_subscriptions.pop(subscription_name, None) + logger.info( + f"[SUBSCRIPTION:{subscription_name}] Subscription loop ended — " + f"removed from active_subscriptions. Final state: " + f"{self.connection_states.get(subscription_name, 'unknown')}" + ) + async def get_resource_data(self, resource_name: str) -> dict[str, Any] | None: """Get current resource data with enhanced logging.""" logger.debug(f"[RESOURCE:{resource_name}] Resource data requested") diff --git a/unraid_mcp/tools/docker.py b/unraid_mcp/tools/docker.py index 0568f64..cd31e7b 100644 --- a/unraid_mcp/tools/docker.py +++ b/unraid_mcp/tools/docker.py @@ -99,8 +99,7 @@ MUTATIONS: dict[str, str] = { """, } -DESTRUCTIVE_ACTIONS = {"remove"} -_MUTATION_ACTIONS = {"start", "stop", "restart", "pause", "unpause", "remove", "update"} +DESTRUCTIVE_ACTIONS = {"remove", "update_all"} # NOTE (Code-M-07): "details" and "logs" are listed here because they require a # container_id parameter, but unlike mutations they use fuzzy name matching (not # strict). This is intentional: read-only queries are safe with fuzzy matching. diff --git a/unraid_mcp/tools/health.py b/unraid_mcp/tools/health.py index f378e6d..dc34559 100644 --- a/unraid_mcp/tools/health.py +++ b/unraid_mcp/tools/health.py @@ -37,8 +37,8 @@ def _safe_display_url(url: str | None) -> str | None: if parsed.port: return f"{parsed.scheme}://{host}:{parsed.port}" return f"{parsed.scheme}://{host}" - except Exception: - # If parsing fails, show nothing rather than leaking the raw URL + except ValueError: + # urlparse raises ValueError for invalid URLs (e.g. contains control chars) return "" @@ -235,9 +235,9 @@ def _analyze_subscription_status( """Analyze subscription status dict, returning error count and connection issues. This is the canonical implementation of subscription status analysis. - TODO: subscriptions/diagnostics.py (lines 168-182) duplicates this logic. - That module should be refactored to call this helper once file ownership - allows cross-agent edits. See Code-H05. + TODO: subscriptions/diagnostics.py has a similar status-analysis pattern + in diagnose_subscriptions(). That module could import and call this helper + directly to avoid divergence. See Code-H05. Args: status: Dict of subscription name -> status info from get_subscription_status().