forked from HomeLab/unraid-mcp
test: close critical coverage gaps and harden PR review fixes
Critical bug fixes from PR review agents: - client.py: eager asyncio.Lock init, Final[frozenset] for _SENSITIVE_KEYS, explicit 429 ToolError after retries exhausted, removed lazy _get_client_lock() and _RateLimiter._get_lock() patterns - exceptions.py: use builtin TimeoutError (UP041), explicit handler before broad except so asyncio timeouts get descriptive messages - docker.py: add update_all to DESTRUCTIVE_ACTIONS (was missing), remove dead _MUTATION_ACTIONS constant - manager.py: _cap_log_content returns new dict (immutable), lock write to resource_data, clean dead task from active_subscriptions after loop exits - diagnostics.py: fix inaccurate comment about semicolon injection guard - health.py: narrow except ValueError in _safe_display_url, fix TODO comment New test coverage (98 tests added, 529 → 598 passing): - test_subscription_validation.py: 27 tests for _validate_subscription_query (security-critical allow-list, forbidden keyword guards, word-boundary test) - test_subscription_manager.py: 12 tests for _cap_log_content (immutability, truncation, nesting, passthrough) - test_client.py: +57 tests — _RateLimiter (token math, refill, sleep-on-empty), _QueryCache (TTL, invalidation, is_cacheable), 429 retry loop (1/2/3 failures) - test_health.py: +10 tests for _safe_display_url (credential strip, port, path/query removal, malformed IPv6 → <unparseable>) - test_notifications.py: +7 importance enum and field length validation tests - test_rclone.py: +7 _validate_config_data security guard tests - test_storage.py: +15 (tail_lines bounds, format_kb, safe_get) - test_docker.py: update_all now requires confirm=True + new guard test - test_destructive_guards.py: update audit to include update_all Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -39,7 +39,7 @@ KNOWN_DESTRUCTIVE: dict[str, dict[str, set[str]]] = {
|
|||||||
"module": "unraid_mcp.tools.docker",
|
"module": "unraid_mcp.tools.docker",
|
||||||
"register_fn": "register_docker_tool",
|
"register_fn": "register_docker_tool",
|
||||||
"tool_name": "unraid_docker",
|
"tool_name": "unraid_docker",
|
||||||
"actions": {"remove"},
|
"actions": {"remove", "update_all"},
|
||||||
"runtime_set": DOCKER_DESTRUCTIVE,
|
"runtime_set": DOCKER_DESTRUCTIVE,
|
||||||
},
|
},
|
||||||
"vm": {
|
"vm": {
|
||||||
@@ -143,6 +143,7 @@ class TestDestructiveActionRegistries:
|
|||||||
_DESTRUCTIVE_TEST_CASES: list[tuple[str, str, dict]] = [
|
_DESTRUCTIVE_TEST_CASES: list[tuple[str, str, dict]] = [
|
||||||
# Docker
|
# Docker
|
||||||
("docker", "remove", {"container_id": "abc123"}),
|
("docker", "remove", {"container_id": "abc123"}),
|
||||||
|
("docker", "update_all", {}),
|
||||||
# VM
|
# VM
|
||||||
("vm", "force_stop", {"vm_id": "test-vm-uuid"}),
|
("vm", "force_stop", {"vm_id": "test-vm-uuid"}),
|
||||||
("vm", "reset", {"vm_id": "test-vm-uuid"}),
|
("vm", "reset", {"vm_id": "test-vm-uuid"}),
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Tests for unraid_mcp.core.client — GraphQL client infrastructure."""
|
"""Tests for unraid_mcp.core.client — GraphQL client infrastructure."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import time
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@@ -9,6 +10,8 @@ import pytest
|
|||||||
from unraid_mcp.core.client import (
|
from unraid_mcp.core.client import (
|
||||||
DEFAULT_TIMEOUT,
|
DEFAULT_TIMEOUT,
|
||||||
DISK_TIMEOUT,
|
DISK_TIMEOUT,
|
||||||
|
_QueryCache,
|
||||||
|
_RateLimiter,
|
||||||
_redact_sensitive,
|
_redact_sensitive,
|
||||||
is_idempotent_error,
|
is_idempotent_error,
|
||||||
make_graphql_request,
|
make_graphql_request,
|
||||||
@@ -464,3 +467,231 @@ class TestGraphQLErrorHandling:
|
|||||||
pytest.raises(ToolError, match="GraphQL API error"),
|
pytest.raises(ToolError, match="GraphQL API error"),
|
||||||
):
|
):
|
||||||
await make_graphql_request("{ info }")
|
await make_graphql_request("{ info }")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _RateLimiter
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRateLimiter:
|
||||||
|
"""Unit tests for the token bucket rate limiter."""
|
||||||
|
|
||||||
|
async def test_acquire_consumes_one_token(self) -> None:
|
||||||
|
limiter = _RateLimiter(max_tokens=10, refill_rate=1.0)
|
||||||
|
initial = limiter.tokens
|
||||||
|
await limiter.acquire()
|
||||||
|
assert limiter.tokens == initial - 1
|
||||||
|
|
||||||
|
async def test_acquire_succeeds_when_tokens_available(self) -> None:
|
||||||
|
limiter = _RateLimiter(max_tokens=5, refill_rate=1.0)
|
||||||
|
# Should complete without sleeping
|
||||||
|
for _ in range(5):
|
||||||
|
await limiter.acquire()
|
||||||
|
# _refill() runs during each acquire() call and adds a tiny time-based
|
||||||
|
# amount; check < 1.0 (not enough for another immediate request) rather
|
||||||
|
# than == 0.0 to avoid flakiness from timing.
|
||||||
|
assert limiter.tokens < 1.0
|
||||||
|
|
||||||
|
async def test_tokens_do_not_exceed_max(self) -> None:
|
||||||
|
limiter = _RateLimiter(max_tokens=10, refill_rate=1.0)
|
||||||
|
# Force refill with large elapsed time
|
||||||
|
limiter.last_refill = time.monotonic() - 100.0 # 100 seconds ago
|
||||||
|
limiter._refill()
|
||||||
|
assert limiter.tokens == 10.0 # Capped at max_tokens
|
||||||
|
|
||||||
|
async def test_refill_adds_tokens_based_on_elapsed(self) -> None:
|
||||||
|
limiter = _RateLimiter(max_tokens=100, refill_rate=10.0)
|
||||||
|
limiter.tokens = 0.0
|
||||||
|
limiter.last_refill = time.monotonic() - 1.0 # 1 second ago
|
||||||
|
limiter._refill()
|
||||||
|
# Should have refilled ~10 tokens (10.0 rate * 1.0 sec)
|
||||||
|
assert 9.5 < limiter.tokens < 10.5
|
||||||
|
|
||||||
|
async def test_acquire_sleeps_when_no_tokens(self) -> None:
|
||||||
|
"""When tokens are exhausted, acquire should sleep before consuming."""
|
||||||
|
limiter = _RateLimiter(max_tokens=1, refill_rate=1.0)
|
||||||
|
limiter.tokens = 0.0
|
||||||
|
|
||||||
|
sleep_calls = []
|
||||||
|
|
||||||
|
async def fake_sleep(duration: float) -> None:
|
||||||
|
sleep_calls.append(duration)
|
||||||
|
# Simulate refill by advancing last_refill so tokens replenish
|
||||||
|
limiter.tokens = 1.0
|
||||||
|
limiter.last_refill = time.monotonic()
|
||||||
|
|
||||||
|
with patch("unraid_mcp.core.client.asyncio.sleep", side_effect=fake_sleep):
|
||||||
|
await limiter.acquire()
|
||||||
|
|
||||||
|
assert len(sleep_calls) == 1
|
||||||
|
assert sleep_calls[0] > 0
|
||||||
|
|
||||||
|
async def test_default_params_match_api_limits(self) -> None:
|
||||||
|
"""Default rate limiter must use 90 tokens at 9.0/sec (10% headroom from 100/10s)."""
|
||||||
|
limiter = _RateLimiter()
|
||||||
|
assert limiter.max_tokens == 90
|
||||||
|
assert limiter.refill_rate == 9.0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _QueryCache
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestQueryCache:
|
||||||
|
"""Unit tests for the TTL query cache."""
|
||||||
|
|
||||||
|
def test_miss_on_empty_cache(self) -> None:
|
||||||
|
cache = _QueryCache()
|
||||||
|
assert cache.get("{ info }", None) is None
|
||||||
|
|
||||||
|
def test_put_and_get_hit(self) -> None:
|
||||||
|
cache = _QueryCache()
|
||||||
|
data = {"result": "ok"}
|
||||||
|
cache.put("GetNetworkConfig { }", None, data)
|
||||||
|
result = cache.get("GetNetworkConfig { }", None)
|
||||||
|
assert result == data
|
||||||
|
|
||||||
|
def test_expired_entry_returns_none(self) -> None:
|
||||||
|
cache = _QueryCache()
|
||||||
|
data = {"result": "ok"}
|
||||||
|
cache.put("GetNetworkConfig { }", None, data)
|
||||||
|
# Manually expire the entry
|
||||||
|
key = cache._cache_key("GetNetworkConfig { }", None)
|
||||||
|
cache._store[key] = (time.monotonic() - 1.0, data) # expired 1 sec ago
|
||||||
|
assert cache.get("GetNetworkConfig { }", None) is None
|
||||||
|
|
||||||
|
def test_invalidate_all_clears_store(self) -> None:
|
||||||
|
cache = _QueryCache()
|
||||||
|
cache.put("GetNetworkConfig { }", None, {"x": 1})
|
||||||
|
cache.put("GetOwner { }", None, {"y": 2})
|
||||||
|
assert len(cache._store) == 2
|
||||||
|
cache.invalidate_all()
|
||||||
|
assert len(cache._store) == 0
|
||||||
|
|
||||||
|
def test_variables_affect_cache_key(self) -> None:
|
||||||
|
"""Different variables produce different cache keys."""
|
||||||
|
cache = _QueryCache()
|
||||||
|
q = "GetNetworkConfig($id: ID!) { network(id: $id) { name } }"
|
||||||
|
cache.put(q, {"id": "1"}, {"name": "eth0"})
|
||||||
|
cache.put(q, {"id": "2"}, {"name": "eth1"})
|
||||||
|
assert cache.get(q, {"id": "1"}) == {"name": "eth0"}
|
||||||
|
assert cache.get(q, {"id": "2"}) == {"name": "eth1"}
|
||||||
|
|
||||||
|
def test_is_cacheable_returns_true_for_known_prefixes(self) -> None:
|
||||||
|
assert _QueryCache.is_cacheable("GetNetworkConfig { ... }") is True
|
||||||
|
assert _QueryCache.is_cacheable("GetRegistrationInfo { ... }") is True
|
||||||
|
assert _QueryCache.is_cacheable("GetOwner { ... }") is True
|
||||||
|
assert _QueryCache.is_cacheable("GetFlash { ... }") is True
|
||||||
|
|
||||||
|
def test_is_cacheable_returns_false_for_mutations(self) -> None:
|
||||||
|
assert _QueryCache.is_cacheable('mutation { docker { start(id: "x") } }') is False
|
||||||
|
|
||||||
|
def test_is_cacheable_returns_false_for_unlisted_queries(self) -> None:
|
||||||
|
assert _QueryCache.is_cacheable("{ docker { containers { id } } }") is False
|
||||||
|
assert _QueryCache.is_cacheable("{ info { os } }") is False
|
||||||
|
|
||||||
|
def test_is_cacheable_mutation_check_is_prefix(self) -> None:
|
||||||
|
"""Queries that start with 'mutation' after whitespace are not cacheable."""
|
||||||
|
assert _QueryCache.is_cacheable(" mutation { ... }") is False
|
||||||
|
|
||||||
|
def test_expired_entry_removed_from_store(self) -> None:
|
||||||
|
"""Accessing an expired entry should remove it from the internal store."""
|
||||||
|
cache = _QueryCache()
|
||||||
|
cache.put("GetOwner { }", None, {"owner": "root"})
|
||||||
|
key = cache._cache_key("GetOwner { }", None)
|
||||||
|
cache._store[key] = (time.monotonic() - 1.0, {"owner": "root"})
|
||||||
|
assert key in cache._store
|
||||||
|
cache.get("GetOwner { }", None) # triggers deletion
|
||||||
|
assert key not in cache._store
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# make_graphql_request — 429 retry behavior
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRateLimitRetry:
|
||||||
|
"""Tests for the 429 retry loop in make_graphql_request."""
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _patch_config(self):
|
||||||
|
with (
|
||||||
|
patch("unraid_mcp.core.client.UNRAID_API_URL", "https://unraid.local/graphql"),
|
||||||
|
patch("unraid_mcp.core.client.UNRAID_API_KEY", "test-key"),
|
||||||
|
patch("unraid_mcp.core.client.asyncio.sleep", new_callable=AsyncMock),
|
||||||
|
):
|
||||||
|
yield
|
||||||
|
|
||||||
|
def _make_429_response(self) -> MagicMock:
|
||||||
|
resp = MagicMock()
|
||||||
|
resp.status_code = 429
|
||||||
|
resp.raise_for_status = MagicMock()
|
||||||
|
return resp
|
||||||
|
|
||||||
|
def _make_ok_response(self, data: dict) -> MagicMock:
|
||||||
|
resp = MagicMock()
|
||||||
|
resp.status_code = 200
|
||||||
|
resp.raise_for_status = MagicMock()
|
||||||
|
resp.json.return_value = {"data": data}
|
||||||
|
return resp
|
||||||
|
|
||||||
|
async def test_single_429_then_success_retries(self) -> None:
|
||||||
|
"""One 429 followed by a success should return the data."""
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post.side_effect = [
|
||||||
|
self._make_429_response(),
|
||||||
|
self._make_ok_response({"info": {"os": "Unraid"}}),
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch("unraid_mcp.core.client.get_http_client", return_value=mock_client):
|
||||||
|
result = await make_graphql_request("{ info { os } }")
|
||||||
|
|
||||||
|
assert result == {"info": {"os": "Unraid"}}
|
||||||
|
assert mock_client.post.call_count == 2
|
||||||
|
|
||||||
|
async def test_two_429s_then_success(self) -> None:
|
||||||
|
"""Two 429s followed by success returns data after 2 retries."""
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post.side_effect = [
|
||||||
|
self._make_429_response(),
|
||||||
|
self._make_429_response(),
|
||||||
|
self._make_ok_response({"x": 1}),
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch("unraid_mcp.core.client.get_http_client", return_value=mock_client):
|
||||||
|
result = await make_graphql_request("{ x }")
|
||||||
|
|
||||||
|
assert result == {"x": 1}
|
||||||
|
assert mock_client.post.call_count == 3
|
||||||
|
|
||||||
|
async def test_three_429s_raises_tool_error(self) -> None:
|
||||||
|
"""Three consecutive 429s (all retries exhausted) raises ToolError."""
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post.side_effect = [
|
||||||
|
self._make_429_response(),
|
||||||
|
self._make_429_response(),
|
||||||
|
self._make_429_response(),
|
||||||
|
]
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
|
||||||
|
pytest.raises(ToolError, match="rate limiting"),
|
||||||
|
):
|
||||||
|
await make_graphql_request("{ info }")
|
||||||
|
|
||||||
|
async def test_rate_limit_error_message_advises_wait(self) -> None:
|
||||||
|
"""The ToolError message should tell the user to wait ~10 seconds."""
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post.side_effect = [
|
||||||
|
self._make_429_response(),
|
||||||
|
self._make_429_response(),
|
||||||
|
self._make_429_response(),
|
||||||
|
]
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
|
||||||
|
pytest.raises(ToolError, match="10 seconds"),
|
||||||
|
):
|
||||||
|
await make_graphql_request("{ info }")
|
||||||
|
|||||||
@@ -175,7 +175,7 @@ class TestDockerActions:
|
|||||||
"docker": {"updateAllContainers": [{"id": "c1", "state": "running"}]}
|
"docker": {"updateAllContainers": [{"id": "c1", "state": "running"}]}
|
||||||
}
|
}
|
||||||
tool_fn = _make_tool()
|
tool_fn = _make_tool()
|
||||||
result = await tool_fn(action="update_all")
|
result = await tool_fn(action="update_all", confirm=True)
|
||||||
assert result["success"] is True
|
assert result["success"] is True
|
||||||
assert len(result["containers"]) == 1
|
assert len(result["containers"]) == 1
|
||||||
|
|
||||||
@@ -271,10 +271,16 @@ class TestDockerMutationFailures:
|
|||||||
"""update_all with no containers to update."""
|
"""update_all with no containers to update."""
|
||||||
_mock_graphql.return_value = {"docker": {"updateAllContainers": []}}
|
_mock_graphql.return_value = {"docker": {"updateAllContainers": []}}
|
||||||
tool_fn = _make_tool()
|
tool_fn = _make_tool()
|
||||||
result = await tool_fn(action="update_all")
|
result = await tool_fn(action="update_all", confirm=True)
|
||||||
assert result["success"] is True
|
assert result["success"] is True
|
||||||
assert result["containers"] == []
|
assert result["containers"] == []
|
||||||
|
|
||||||
|
async def test_update_all_requires_confirm(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
"""update_all is destructive and requires confirm=True."""
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
with pytest.raises(ToolError, match="destructive"):
|
||||||
|
await tool_fn(action="update_all")
|
||||||
|
|
||||||
async def test_mutation_timeout(self, _mock_graphql: AsyncMock) -> None:
|
async def test_mutation_timeout(self, _mock_graphql: AsyncMock) -> None:
|
||||||
"""Mid-operation timeout during a docker mutation."""
|
"""Mid-operation timeout during a docker mutation."""
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import pytest
|
|||||||
from conftest import make_tool_fn
|
from conftest import make_tool_fn
|
||||||
|
|
||||||
from unraid_mcp.core.exceptions import ToolError
|
from unraid_mcp.core.exceptions import ToolError
|
||||||
|
from unraid_mcp.tools.health import _safe_display_url
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -139,3 +140,55 @@ class TestHealthActions:
|
|||||||
finally:
|
finally:
|
||||||
# Restore cached modules
|
# Restore cached modules
|
||||||
sys.modules.update(cached)
|
sys.modules.update(cached)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _safe_display_url — URL redaction helper
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSafeDisplayUrl:
|
||||||
|
"""Verify that _safe_display_url strips credentials/path and preserves scheme+host+port."""
|
||||||
|
|
||||||
|
def test_none_returns_none(self) -> None:
|
||||||
|
assert _safe_display_url(None) is None
|
||||||
|
|
||||||
|
def test_empty_string_returns_none(self) -> None:
|
||||||
|
assert _safe_display_url("") is None
|
||||||
|
|
||||||
|
def test_simple_url_scheme_and_host(self) -> None:
|
||||||
|
assert _safe_display_url("https://unraid.local/graphql") == "https://unraid.local"
|
||||||
|
|
||||||
|
def test_preserves_port(self) -> None:
|
||||||
|
assert _safe_display_url("https://10.1.0.2:31337/api/graphql") == "https://10.1.0.2:31337"
|
||||||
|
|
||||||
|
def test_strips_path(self) -> None:
|
||||||
|
result = _safe_display_url("http://unraid.local/some/deep/path?query=1")
|
||||||
|
assert "path" not in result
|
||||||
|
assert "query" not in result
|
||||||
|
|
||||||
|
def test_strips_credentials(self) -> None:
|
||||||
|
result = _safe_display_url("https://user:password@unraid.local/graphql")
|
||||||
|
assert "user" not in result
|
||||||
|
assert "password" not in result
|
||||||
|
assert result == "https://unraid.local"
|
||||||
|
|
||||||
|
def test_strips_query_params(self) -> None:
|
||||||
|
result = _safe_display_url("http://host.local?token=abc&key=xyz")
|
||||||
|
assert "token" not in result
|
||||||
|
assert "abc" not in result
|
||||||
|
|
||||||
|
def test_http_scheme_preserved(self) -> None:
|
||||||
|
result = _safe_display_url("http://10.0.0.1:8080/api")
|
||||||
|
assert result == "http://10.0.0.1:8080"
|
||||||
|
|
||||||
|
def test_tailscale_url(self) -> None:
|
||||||
|
result = _safe_display_url("https://100.118.209.1:31337/graphql")
|
||||||
|
assert result == "https://100.118.209.1:31337"
|
||||||
|
|
||||||
|
def test_malformed_ipv6_url_returns_unparseable(self) -> None:
|
||||||
|
"""Malformed IPv6 brackets in netloc cause urlparse.hostname to raise ValueError."""
|
||||||
|
# urlparse("https://[invalid") parses without error, but accessing .hostname
|
||||||
|
# raises ValueError: Invalid IPv6 URL — this triggers the except branch.
|
||||||
|
result = _safe_display_url("https://[invalid")
|
||||||
|
assert result == "<unparseable>"
|
||||||
|
|||||||
@@ -151,3 +151,87 @@ class TestNotificationsActions:
|
|||||||
tool_fn = _make_tool()
|
tool_fn = _make_tool()
|
||||||
with pytest.raises(ToolError, match="boom"):
|
with pytest.raises(ToolError, match="boom"):
|
||||||
await tool_fn(action="overview")
|
await tool_fn(action="overview")
|
||||||
|
|
||||||
|
|
||||||
|
class TestNotificationsCreateValidation:
|
||||||
|
"""Tests for importance enum and field length validation added in this PR."""
|
||||||
|
|
||||||
|
async def test_invalid_importance_rejected(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
with pytest.raises(ToolError, match="importance must be one of"):
|
||||||
|
await tool_fn(
|
||||||
|
action="create",
|
||||||
|
title="T",
|
||||||
|
subject="S",
|
||||||
|
description="D",
|
||||||
|
importance="invalid",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_info_importance_rejected(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
"""INFO is listed in old docstring examples but rejected by the validator."""
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
with pytest.raises(ToolError, match="importance must be one of"):
|
||||||
|
await tool_fn(
|
||||||
|
action="create",
|
||||||
|
title="T",
|
||||||
|
subject="S",
|
||||||
|
description="D",
|
||||||
|
importance="info",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_alert_importance_accepted(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
_mock_graphql.return_value = {
|
||||||
|
"notifications": {"createNotification": {"id": "n:1", "importance": "ALERT"}}
|
||||||
|
}
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
result = await tool_fn(
|
||||||
|
action="create", title="T", subject="S", description="D", importance="alert"
|
||||||
|
)
|
||||||
|
assert result["success"] is True
|
||||||
|
|
||||||
|
async def test_title_too_long_rejected(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
with pytest.raises(ToolError, match="title must be at most 200"):
|
||||||
|
await tool_fn(
|
||||||
|
action="create",
|
||||||
|
title="x" * 201,
|
||||||
|
subject="S",
|
||||||
|
description="D",
|
||||||
|
importance="normal",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_subject_too_long_rejected(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
with pytest.raises(ToolError, match="subject must be at most 500"):
|
||||||
|
await tool_fn(
|
||||||
|
action="create",
|
||||||
|
title="T",
|
||||||
|
subject="x" * 501,
|
||||||
|
description="D",
|
||||||
|
importance="normal",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_description_too_long_rejected(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
with pytest.raises(ToolError, match="description must be at most 2000"):
|
||||||
|
await tool_fn(
|
||||||
|
action="create",
|
||||||
|
title="T",
|
||||||
|
subject="S",
|
||||||
|
description="x" * 2001,
|
||||||
|
importance="normal",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_title_at_max_accepted(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
_mock_graphql.return_value = {
|
||||||
|
"notifications": {"createNotification": {"id": "n:1", "importance": "NORMAL"}}
|
||||||
|
}
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
result = await tool_fn(
|
||||||
|
action="create",
|
||||||
|
title="x" * 200,
|
||||||
|
subject="S",
|
||||||
|
description="D",
|
||||||
|
importance="normal",
|
||||||
|
)
|
||||||
|
assert result["success"] is True
|
||||||
|
|||||||
@@ -100,3 +100,83 @@ class TestRcloneActions:
|
|||||||
tool_fn = _make_tool()
|
tool_fn = _make_tool()
|
||||||
with pytest.raises(ToolError, match="Failed to delete"):
|
with pytest.raises(ToolError, match="Failed to delete"):
|
||||||
await tool_fn(action="delete_remote", name="gdrive", confirm=True)
|
await tool_fn(action="delete_remote", name="gdrive", confirm=True)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRcloneConfigDataValidation:
|
||||||
|
"""Tests for _validate_config_data security guards."""
|
||||||
|
|
||||||
|
async def test_path_traversal_in_key_rejected(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
with pytest.raises(ToolError, match="disallowed characters"):
|
||||||
|
await tool_fn(
|
||||||
|
action="create_remote",
|
||||||
|
name="r",
|
||||||
|
provider_type="s3",
|
||||||
|
config_data={"../evil": "value"},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_shell_metachar_in_key_rejected(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
with pytest.raises(ToolError, match="disallowed characters"):
|
||||||
|
await tool_fn(
|
||||||
|
action="create_remote",
|
||||||
|
name="r",
|
||||||
|
provider_type="s3",
|
||||||
|
config_data={"key;rm": "value"},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_too_many_keys_rejected(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
with pytest.raises(ToolError, match="max 50"):
|
||||||
|
await tool_fn(
|
||||||
|
action="create_remote",
|
||||||
|
name="r",
|
||||||
|
provider_type="s3",
|
||||||
|
config_data={f"key{i}": "v" for i in range(51)},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_dict_value_rejected(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
with pytest.raises(ToolError, match="string, number, or boolean"):
|
||||||
|
await tool_fn(
|
||||||
|
action="create_remote",
|
||||||
|
name="r",
|
||||||
|
provider_type="s3",
|
||||||
|
config_data={"nested": {"key": "val"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_value_too_long_rejected(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
with pytest.raises(ToolError, match="exceeds max length"):
|
||||||
|
await tool_fn(
|
||||||
|
action="create_remote",
|
||||||
|
name="r",
|
||||||
|
provider_type="s3",
|
||||||
|
config_data={"key": "x" * 4097},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_boolean_value_accepted(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
_mock_graphql.return_value = {
|
||||||
|
"rclone": {"createRCloneRemote": {"name": "r", "type": "s3"}}
|
||||||
|
}
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
result = await tool_fn(
|
||||||
|
action="create_remote",
|
||||||
|
name="r",
|
||||||
|
provider_type="s3",
|
||||||
|
config_data={"use_path_style": True},
|
||||||
|
)
|
||||||
|
assert result["success"] is True
|
||||||
|
|
||||||
|
async def test_int_value_accepted(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
_mock_graphql.return_value = {
|
||||||
|
"rclone": {"createRCloneRemote": {"name": "r", "type": "sftp"}}
|
||||||
|
}
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
result = await tool_fn(
|
||||||
|
action="create_remote",
|
||||||
|
name="r",
|
||||||
|
provider_type="sftp",
|
||||||
|
config_data={"port": 22},
|
||||||
|
)
|
||||||
|
assert result["success"] is True
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import pytest
|
|||||||
from conftest import make_tool_fn
|
from conftest import make_tool_fn
|
||||||
|
|
||||||
from unraid_mcp.core.exceptions import ToolError
|
from unraid_mcp.core.exceptions import ToolError
|
||||||
from unraid_mcp.core.utils import format_bytes
|
from unraid_mcp.core.utils import format_bytes, format_kb, safe_get
|
||||||
|
|
||||||
|
|
||||||
# --- Unit tests for helpers ---
|
# --- Unit tests for helpers ---
|
||||||
@@ -77,6 +77,70 @@ class TestStorageValidation:
|
|||||||
result = await tool_fn(action="logs", log_path="/var/log/syslog")
|
result = await tool_fn(action="logs", log_path="/var/log/syslog")
|
||||||
assert result["content"] == "ok"
|
assert result["content"] == "ok"
|
||||||
|
|
||||||
|
async def test_logs_tail_lines_too_large(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
with pytest.raises(ToolError, match="tail_lines must be between"):
|
||||||
|
await tool_fn(action="logs", log_path="/var/log/syslog", tail_lines=10_001)
|
||||||
|
|
||||||
|
async def test_logs_tail_lines_zero_rejected(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
with pytest.raises(ToolError, match="tail_lines must be between"):
|
||||||
|
await tool_fn(action="logs", log_path="/var/log/syslog", tail_lines=0)
|
||||||
|
|
||||||
|
async def test_logs_tail_lines_at_max_accepted(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
_mock_graphql.return_value = {"logFile": {"path": "/var/log/syslog", "content": "ok"}}
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
result = await tool_fn(action="logs", log_path="/var/log/syslog", tail_lines=10_000)
|
||||||
|
assert result["content"] == "ok"
|
||||||
|
|
||||||
|
|
||||||
|
class TestFormatKb:
|
||||||
|
def test_none_returns_na(self) -> None:
|
||||||
|
assert format_kb(None) == "N/A"
|
||||||
|
|
||||||
|
def test_invalid_string_returns_na(self) -> None:
|
||||||
|
assert format_kb("not-a-number") == "N/A"
|
||||||
|
|
||||||
|
def test_kilobytes_range(self) -> None:
|
||||||
|
assert format_kb(512) == "512 KB"
|
||||||
|
|
||||||
|
def test_megabytes_range(self) -> None:
|
||||||
|
assert format_kb(2048) == "2.00 MB"
|
||||||
|
|
||||||
|
def test_gigabytes_range(self) -> None:
|
||||||
|
assert format_kb(1_048_576) == "1.00 GB"
|
||||||
|
|
||||||
|
def test_terabytes_range(self) -> None:
|
||||||
|
assert format_kb(1_073_741_824) == "1.00 TB"
|
||||||
|
|
||||||
|
def test_boundary_exactly_1024_kb(self) -> None:
|
||||||
|
# 1024 KB = 1 MB
|
||||||
|
assert format_kb(1024) == "1.00 MB"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSafeGet:
|
||||||
|
def test_simple_key_access(self) -> None:
|
||||||
|
assert safe_get({"a": 1}, "a") == 1
|
||||||
|
|
||||||
|
def test_nested_key_access(self) -> None:
|
||||||
|
assert safe_get({"a": {"b": "val"}}, "a", "b") == "val"
|
||||||
|
|
||||||
|
def test_missing_key_returns_none(self) -> None:
|
||||||
|
assert safe_get({"a": 1}, "missing") is None
|
||||||
|
|
||||||
|
def test_none_intermediate_returns_default(self) -> None:
|
||||||
|
assert safe_get({"a": None}, "a", "b") is None
|
||||||
|
|
||||||
|
def test_custom_default_returned(self) -> None:
|
||||||
|
assert safe_get({}, "x", default="fallback") == "fallback"
|
||||||
|
|
||||||
|
def test_non_dict_intermediate_returns_default(self) -> None:
|
||||||
|
assert safe_get({"a": "string"}, "a", "b") is None
|
||||||
|
|
||||||
|
def test_empty_list_default(self) -> None:
|
||||||
|
result = safe_get({}, "missing", default=[])
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
class TestStorageActions:
|
class TestStorageActions:
|
||||||
async def test_shares(self, _mock_graphql: AsyncMock) -> None:
|
async def test_shares(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
|||||||
131
tests/test_subscription_manager.py
Normal file
131
tests/test_subscription_manager.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
"""Tests for _cap_log_content in subscriptions/manager.py.
|
||||||
|
|
||||||
|
_cap_log_content is a pure utility that prevents unbounded memory growth from
|
||||||
|
log subscription data. It must: return a NEW dict (not mutate), recursively
|
||||||
|
cap nested 'content' fields, and only truncate when both byte limit and line
|
||||||
|
limit are exceeded.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from unraid_mcp.subscriptions.manager import _cap_log_content
|
||||||
|
|
||||||
|
|
||||||
|
class TestCapLogContentImmutability:
|
||||||
|
"""The function must return a new dict — never mutate the input."""
|
||||||
|
|
||||||
|
def test_returns_new_dict(self) -> None:
|
||||||
|
data = {"key": "value"}
|
||||||
|
result = _cap_log_content(data)
|
||||||
|
assert result is not data
|
||||||
|
|
||||||
|
def test_input_not_mutated_on_passthrough(self) -> None:
|
||||||
|
data = {"content": "short text", "other": "value"}
|
||||||
|
original_content = data["content"]
|
||||||
|
_cap_log_content(data)
|
||||||
|
assert data["content"] == original_content
|
||||||
|
|
||||||
|
def test_input_not_mutated_on_truncation(self) -> None:
|
||||||
|
# Use small limits so the truncation path is exercised
|
||||||
|
large_content = "\n".join(f"line {i}" for i in range(200))
|
||||||
|
data = {"content": large_content}
|
||||||
|
with (
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10),
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50),
|
||||||
|
):
|
||||||
|
_cap_log_content(data)
|
||||||
|
# Original data must be unchanged
|
||||||
|
assert data["content"] == large_content
|
||||||
|
|
||||||
|
|
||||||
|
class TestCapLogContentSmallData:
|
||||||
|
"""Content below the byte limit must be returned unchanged."""
|
||||||
|
|
||||||
|
def test_small_content_unchanged(self) -> None:
|
||||||
|
data = {"content": "just a few lines\nof log data\n"}
|
||||||
|
result = _cap_log_content(data)
|
||||||
|
assert result["content"] == data["content"]
|
||||||
|
|
||||||
|
def test_non_content_keys_passed_through(self) -> None:
|
||||||
|
data = {"name": "cpu_subscription", "timestamp": "2026-02-18T00:00:00Z"}
|
||||||
|
result = _cap_log_content(data)
|
||||||
|
assert result == data
|
||||||
|
|
||||||
|
def test_integer_value_passed_through(self) -> None:
|
||||||
|
data = {"count": 42, "active": True}
|
||||||
|
result = _cap_log_content(data)
|
||||||
|
assert result == data
|
||||||
|
|
||||||
|
|
||||||
|
class TestCapLogContentTruncation:
|
||||||
|
"""Content exceeding both byte AND line limits must be truncated to the last N lines."""
|
||||||
|
|
||||||
|
def test_oversized_content_truncated_to_last_n_lines(self) -> None:
|
||||||
|
# 200 lines, limit 50 lines, byte limit effectively 0 → should keep last 50 lines
|
||||||
|
lines = [f"line {i}" for i in range(200)]
|
||||||
|
data = {"content": "\n".join(lines)}
|
||||||
|
with (
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10),
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50),
|
||||||
|
):
|
||||||
|
result = _cap_log_content(data)
|
||||||
|
result_lines = result["content"].splitlines()
|
||||||
|
assert len(result_lines) == 50
|
||||||
|
# Must be the LAST 50 lines
|
||||||
|
assert result_lines[0] == "line 150"
|
||||||
|
assert result_lines[-1] == "line 199"
|
||||||
|
|
||||||
|
def test_content_with_fewer_lines_than_limit_not_truncated(self) -> None:
|
||||||
|
"""If byte limit exceeded but line count ≤ limit → keep original (not truncated)."""
|
||||||
|
# 30 lines but byte limit 10 and line limit 50 → 30 < 50 so no truncation
|
||||||
|
lines = [f"line {i}" for i in range(30)]
|
||||||
|
data = {"content": "\n".join(lines)}
|
||||||
|
with (
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10),
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50),
|
||||||
|
):
|
||||||
|
result = _cap_log_content(data)
|
||||||
|
# Original content preserved
|
||||||
|
assert result["content"] == data["content"]
|
||||||
|
|
||||||
|
def test_non_content_keys_preserved_alongside_truncated_content(self) -> None:
|
||||||
|
lines = [f"line {i}" for i in range(200)]
|
||||||
|
data = {"content": "\n".join(lines), "path": "/var/log/syslog", "total_lines": 200}
|
||||||
|
with (
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10),
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50),
|
||||||
|
):
|
||||||
|
result = _cap_log_content(data)
|
||||||
|
assert result["path"] == "/var/log/syslog"
|
||||||
|
assert result["total_lines"] == 200
|
||||||
|
assert len(result["content"].splitlines()) == 50
|
||||||
|
|
||||||
|
|
||||||
|
class TestCapLogContentNested:
|
||||||
|
"""Nested 'content' fields inside sub-dicts must also be capped recursively."""
|
||||||
|
|
||||||
|
def test_nested_content_field_capped(self) -> None:
|
||||||
|
lines = [f"line {i}" for i in range(200)]
|
||||||
|
data = {"logFile": {"content": "\n".join(lines), "path": "/var/log/syslog"}}
|
||||||
|
with (
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10),
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50),
|
||||||
|
):
|
||||||
|
result = _cap_log_content(data)
|
||||||
|
assert len(result["logFile"]["content"].splitlines()) == 50
|
||||||
|
assert result["logFile"]["path"] == "/var/log/syslog"
|
||||||
|
|
||||||
|
def test_deeply_nested_content_capped(self) -> None:
|
||||||
|
lines = [f"line {i}" for i in range(200)]
|
||||||
|
data = {"outer": {"inner": {"content": "\n".join(lines)}}}
|
||||||
|
with (
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10),
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50),
|
||||||
|
):
|
||||||
|
result = _cap_log_content(data)
|
||||||
|
assert len(result["outer"]["inner"]["content"].splitlines()) == 50
|
||||||
|
|
||||||
|
def test_nested_non_content_keys_unaffected(self) -> None:
|
||||||
|
data = {"metrics": {"cpu": 42.5, "memory": 8192}}
|
||||||
|
result = _cap_log_content(data)
|
||||||
|
assert result == data
|
||||||
131
tests/test_subscription_validation.py
Normal file
131
tests/test_subscription_validation.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
"""Tests for _validate_subscription_query in diagnostics.py.
|
||||||
|
|
||||||
|
Security-critical: this function is the only guard against arbitrary GraphQL
|
||||||
|
operations (mutations, queries) being sent over the WebSocket subscription channel.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from unraid_mcp.core.exceptions import ToolError
|
||||||
|
from unraid_mcp.subscriptions.diagnostics import (
|
||||||
|
_ALLOWED_SUBSCRIPTION_NAMES,
|
||||||
|
_validate_subscription_query,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateSubscriptionQueryAllowed:
|
||||||
|
"""All whitelisted subscription names must be accepted."""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("sub_name", sorted(_ALLOWED_SUBSCRIPTION_NAMES))
|
||||||
|
def test_all_allowed_names_accepted(self, sub_name: str) -> None:
|
||||||
|
query = f"subscription {{ {sub_name} {{ data }} }}"
|
||||||
|
result = _validate_subscription_query(query)
|
||||||
|
assert result == sub_name
|
||||||
|
|
||||||
|
def test_returns_extracted_subscription_name(self) -> None:
|
||||||
|
query = "subscription { cpuSubscription { usage } }"
|
||||||
|
assert _validate_subscription_query(query) == "cpuSubscription"
|
||||||
|
|
||||||
|
def test_leading_whitespace_accepted(self) -> None:
|
||||||
|
query = " subscription { memorySubscription { free } }"
|
||||||
|
assert _validate_subscription_query(query) == "memorySubscription"
|
||||||
|
|
||||||
|
def test_multiline_query_accepted(self) -> None:
|
||||||
|
query = "subscription {\n logFileSubscription {\n content\n }\n}"
|
||||||
|
assert _validate_subscription_query(query) == "logFileSubscription"
|
||||||
|
|
||||||
|
def test_case_insensitive_subscription_keyword(self) -> None:
|
||||||
|
"""'SUBSCRIPTION' should be accepted (regex uses IGNORECASE)."""
|
||||||
|
query = "SUBSCRIPTION { cpuSubscription { usage } }"
|
||||||
|
assert _validate_subscription_query(query) == "cpuSubscription"
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateSubscriptionQueryForbiddenKeywords:
|
||||||
|
"""Queries containing 'mutation' or 'query' as standalone keywords must be rejected."""
|
||||||
|
|
||||||
|
def test_mutation_keyword_rejected(self) -> None:
|
||||||
|
query = 'mutation { docker { start(id: "abc") } }'
|
||||||
|
with pytest.raises(ToolError, match="must be a subscription"):
|
||||||
|
_validate_subscription_query(query)
|
||||||
|
|
||||||
|
def test_query_keyword_rejected(self) -> None:
|
||||||
|
query = "query { info { os { platform } } }"
|
||||||
|
with pytest.raises(ToolError, match="must be a subscription"):
|
||||||
|
_validate_subscription_query(query)
|
||||||
|
|
||||||
|
def test_mutation_embedded_in_subscription_rejected(self) -> None:
|
||||||
|
"""'mutation' anywhere in the string triggers rejection."""
|
||||||
|
query = "subscription { cpuSubscription { mutation data } }"
|
||||||
|
with pytest.raises(ToolError, match="must be a subscription"):
|
||||||
|
_validate_subscription_query(query)
|
||||||
|
|
||||||
|
def test_query_embedded_in_subscription_rejected(self) -> None:
|
||||||
|
query = "subscription { cpuSubscription { query data } }"
|
||||||
|
with pytest.raises(ToolError, match="must be a subscription"):
|
||||||
|
_validate_subscription_query(query)
|
||||||
|
|
||||||
|
def test_mutation_case_insensitive_rejection(self) -> None:
|
||||||
|
query = 'MUTATION { docker { start(id: "abc") } }'
|
||||||
|
with pytest.raises(ToolError, match="must be a subscription"):
|
||||||
|
_validate_subscription_query(query)
|
||||||
|
|
||||||
|
def test_mutation_field_identifier_not_rejected(self) -> None:
|
||||||
|
"""'mutationField' as an identifier must NOT be rejected — only standalone 'mutation'."""
|
||||||
|
# This tests the \b word boundary in _FORBIDDEN_KEYWORDS
|
||||||
|
query = "subscription { cpuSubscription { mutationField } }"
|
||||||
|
# Should not raise — "mutationField" is an identifier, not the keyword
|
||||||
|
result = _validate_subscription_query(query)
|
||||||
|
assert result == "cpuSubscription"
|
||||||
|
|
||||||
|
def test_query_field_identifier_not_rejected(self) -> None:
|
||||||
|
"""'queryResult' as an identifier must NOT be rejected."""
|
||||||
|
query = "subscription { cpuSubscription { queryResult } }"
|
||||||
|
result = _validate_subscription_query(query)
|
||||||
|
assert result == "cpuSubscription"
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateSubscriptionQueryInvalidFormat:
|
||||||
|
"""Queries that don't match the expected subscription format must be rejected."""
|
||||||
|
|
||||||
|
def test_empty_string_rejected(self) -> None:
|
||||||
|
with pytest.raises(ToolError, match="must start with 'subscription'"):
|
||||||
|
_validate_subscription_query("")
|
||||||
|
|
||||||
|
def test_plain_identifier_rejected(self) -> None:
|
||||||
|
with pytest.raises(ToolError, match="must start with 'subscription'"):
|
||||||
|
_validate_subscription_query("cpuSubscription { usage }")
|
||||||
|
|
||||||
|
def test_missing_operation_body_rejected(self) -> None:
|
||||||
|
with pytest.raises(ToolError, match="must start with 'subscription'"):
|
||||||
|
_validate_subscription_query("subscription")
|
||||||
|
|
||||||
|
def test_subscription_without_field_rejected(self) -> None:
|
||||||
|
"""subscription { } with no field name doesn't match the pattern."""
|
||||||
|
with pytest.raises(ToolError, match="must start with 'subscription'"):
|
||||||
|
_validate_subscription_query("subscription { }")
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateSubscriptionQueryUnknownName:
|
||||||
|
"""Subscription names not in the whitelist must be rejected even if format is valid."""
|
||||||
|
|
||||||
|
def test_unknown_subscription_name_rejected(self) -> None:
|
||||||
|
query = "subscription { unknownSubscription { data } }"
|
||||||
|
with pytest.raises(ToolError, match="not allowed"):
|
||||||
|
_validate_subscription_query(query)
|
||||||
|
|
||||||
|
def test_error_message_includes_allowed_list(self) -> None:
|
||||||
|
"""Error message must list the allowed subscription names for usability."""
|
||||||
|
query = "subscription { badSub { data } }"
|
||||||
|
with pytest.raises(ToolError, match="Allowed subscriptions"):
|
||||||
|
_validate_subscription_query(query)
|
||||||
|
|
||||||
|
def test_arbitrary_field_name_rejected(self) -> None:
|
||||||
|
query = "subscription { users { id email } }"
|
||||||
|
with pytest.raises(ToolError, match="not allowed"):
|
||||||
|
_validate_subscription_query(query)
|
||||||
|
|
||||||
|
def test_close_but_not_whitelisted_rejected(self) -> None:
|
||||||
|
"""'cpu' without 'Subscription' suffix is not in the allow-list."""
|
||||||
|
query = "subscription { cpu { usage } }"
|
||||||
|
with pytest.raises(ToolError, match="not allowed"):
|
||||||
|
_validate_subscription_query(query)
|
||||||
@@ -8,7 +8,7 @@ import asyncio
|
|||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import Any, Final
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
@@ -23,8 +23,9 @@ from ..config.settings import (
|
|||||||
from ..core.exceptions import ToolError
|
from ..core.exceptions import ToolError
|
||||||
|
|
||||||
|
|
||||||
# Sensitive keys to redact from debug logs
|
# Sensitive keys to redact from debug logs (frozenset — immutable, Final — no accidental reassignment)
|
||||||
_SENSITIVE_KEYS = {
|
_SENSITIVE_KEYS: Final[frozenset[str]] = frozenset(
|
||||||
|
{
|
||||||
"password",
|
"password",
|
||||||
"key",
|
"key",
|
||||||
"secret",
|
"secret",
|
||||||
@@ -36,7 +37,8 @@ _SENSITIVE_KEYS = {
|
|||||||
"credential",
|
"credential",
|
||||||
"passphrase",
|
"passphrase",
|
||||||
"jwt",
|
"jwt",
|
||||||
}
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _is_sensitive_key(key: str) -> bool:
|
def _is_sensitive_key(key: str) -> bool:
|
||||||
@@ -80,16 +82,9 @@ def get_timeout_for_operation(profile: str) -> httpx.Timeout:
|
|||||||
|
|
||||||
|
|
||||||
# Global connection pool (module-level singleton)
|
# Global connection pool (module-level singleton)
|
||||||
|
# Python 3.12+ asyncio.Lock() is safe at module level — no running event loop required
|
||||||
_http_client: httpx.AsyncClient | None = None
|
_http_client: httpx.AsyncClient | None = None
|
||||||
_client_lock: asyncio.Lock | None = None
|
_client_lock: Final[asyncio.Lock] = asyncio.Lock()
|
||||||
|
|
||||||
|
|
||||||
def _get_client_lock() -> asyncio.Lock:
|
|
||||||
"""Get or create the client lock (lazy init to avoid event loop issues)."""
|
|
||||||
global _client_lock
|
|
||||||
if _client_lock is None:
|
|
||||||
_client_lock = asyncio.Lock()
|
|
||||||
return _client_lock
|
|
||||||
|
|
||||||
|
|
||||||
class _RateLimiter:
|
class _RateLimiter:
|
||||||
@@ -103,12 +98,8 @@ class _RateLimiter:
|
|||||||
self.tokens = float(max_tokens)
|
self.tokens = float(max_tokens)
|
||||||
self.refill_rate = refill_rate # tokens per second
|
self.refill_rate = refill_rate # tokens per second
|
||||||
self.last_refill = time.monotonic()
|
self.last_refill = time.monotonic()
|
||||||
self._lock: asyncio.Lock | None = None
|
# asyncio.Lock() is safe to create at __init__ time (Python 3.12+)
|
||||||
|
self._lock: Final[asyncio.Lock] = asyncio.Lock()
|
||||||
def _get_lock(self) -> asyncio.Lock:
|
|
||||||
if self._lock is None:
|
|
||||||
self._lock = asyncio.Lock()
|
|
||||||
return self._lock
|
|
||||||
|
|
||||||
def _refill(self) -> None:
|
def _refill(self) -> None:
|
||||||
"""Refill tokens based on elapsed time."""
|
"""Refill tokens based on elapsed time."""
|
||||||
@@ -120,7 +111,7 @@ class _RateLimiter:
|
|||||||
async def acquire(self) -> None:
|
async def acquire(self) -> None:
|
||||||
"""Consume one token, waiting if necessary for refill."""
|
"""Consume one token, waiting if necessary for refill."""
|
||||||
while True:
|
while True:
|
||||||
async with self._get_lock():
|
async with self._lock:
|
||||||
self._refill()
|
self._refill()
|
||||||
if self.tokens >= 1:
|
if self.tokens >= 1:
|
||||||
self.tokens -= 1
|
self.tokens -= 1
|
||||||
@@ -266,7 +257,7 @@ async def get_http_client() -> httpx.AsyncClient:
|
|||||||
return client
|
return client
|
||||||
|
|
||||||
# Slow-path: acquire lock for initialization
|
# Slow-path: acquire lock for initialization
|
||||||
async with _get_client_lock():
|
async with _client_lock:
|
||||||
if _http_client is None or _http_client.is_closed:
|
if _http_client is None or _http_client.is_closed:
|
||||||
_http_client = await _create_http_client()
|
_http_client = await _create_http_client()
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -279,7 +270,7 @@ async def close_http_client() -> None:
|
|||||||
"""Close the shared HTTP client (call on server shutdown)."""
|
"""Close the shared HTTP client (call on server shutdown)."""
|
||||||
global _http_client
|
global _http_client
|
||||||
|
|
||||||
async with _get_client_lock():
|
async with _client_lock:
|
||||||
if _http_client is not None:
|
if _http_client is not None:
|
||||||
await _http_client.aclose()
|
await _http_client.aclose()
|
||||||
_http_client = None
|
_http_client = None
|
||||||
@@ -361,6 +352,14 @@ async def make_graphql_request(
|
|||||||
|
|
||||||
if response is None: # pragma: no cover — guaranteed by loop
|
if response is None: # pragma: no cover — guaranteed by loop
|
||||||
raise ToolError("No response received after retry attempts")
|
raise ToolError("No response received after retry attempts")
|
||||||
|
|
||||||
|
# Provide a clear message when all retries are exhausted on 429
|
||||||
|
if response.status_code == 429:
|
||||||
|
logger.error("Rate limit (429) persisted after 3 retries — request aborted")
|
||||||
|
raise ToolError(
|
||||||
|
"Unraid API is rate limiting requests. Wait ~10 seconds before retrying."
|
||||||
|
)
|
||||||
|
|
||||||
response.raise_for_status() # Raise an exception for HTTP error codes 4xx/5xx
|
response.raise_for_status() # Raise an exception for HTTP error codes 4xx/5xx
|
||||||
|
|
||||||
response_data = response.json()
|
response_data = response.json()
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ throughout the application, with proper integration to FastMCP's error system.
|
|||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Generator
|
from collections.abc import Iterator
|
||||||
|
|
||||||
from fastmcp.exceptions import ToolError as FastMCPToolError
|
from fastmcp.exceptions import ToolError as FastMCPToolError
|
||||||
|
|
||||||
@@ -28,11 +28,12 @@ def tool_error_handler(
|
|||||||
tool_name: str,
|
tool_name: str,
|
||||||
action: str,
|
action: str,
|
||||||
logger: logging.Logger,
|
logger: logging.Logger,
|
||||||
) -> Generator[None]:
|
) -> Iterator[None]:
|
||||||
"""Context manager that standardizes tool error handling.
|
"""Context manager that standardizes tool error handling.
|
||||||
|
|
||||||
Re-raises ToolError as-is. Catches all other exceptions, logs them
|
Re-raises ToolError as-is. Gives TimeoutError a descriptive message.
|
||||||
with full traceback, and wraps them in ToolError with a descriptive message.
|
Catches all other exceptions, logs them with full traceback, and wraps them
|
||||||
|
in ToolError with a descriptive message.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tool_name: The tool name for error messages (e.g., "docker", "vm").
|
tool_name: The tool name for error messages (e.g., "docker", "vm").
|
||||||
@@ -43,6 +44,14 @@ def tool_error_handler(
|
|||||||
yield
|
yield
|
||||||
except ToolError:
|
except ToolError:
|
||||||
raise
|
raise
|
||||||
|
except TimeoutError as e:
|
||||||
|
logger.error(
|
||||||
|
f"Timeout in unraid_{tool_name} action={action}: request exceeded time limit",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise ToolError(
|
||||||
|
f"Request timed out executing {tool_name}/{action}. The Unraid API did not respond in time."
|
||||||
|
) from e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in unraid_{tool_name} action={action}: {e}", exc_info=True)
|
logger.error(f"Error in unraid_{tool_name} action={action}: {e}", exc_info=True)
|
||||||
raise ToolError(f"Failed to execute {tool_name}/{action}: {e!s}") from e
|
raise ToolError(f"Failed to execute {tool_name}/{action}: {e!s}") from e
|
||||||
|
|||||||
@@ -37,8 +37,10 @@ _ALLOWED_SUBSCRIPTION_NAMES = frozenset(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Pattern: must start with "subscription", contain only a known subscription name,
|
# Pattern: must start with "subscription" and contain only a known subscription name.
|
||||||
# and not contain mutation/query keywords or semicolons (prevents injection).
|
# _FORBIDDEN_KEYWORDS rejects any query that contains standalone "mutation" or "query"
|
||||||
|
# as distinct words. Word boundaries (\b) ensure "mutationField"-style identifiers are
|
||||||
|
# not rejected — only bare "mutation" or "query" operation keywords are blocked.
|
||||||
_SUBSCRIPTION_NAME_PATTERN = re.compile(r"^\s*subscription\b[^{]*\{\s*(\w+)", re.IGNORECASE)
|
_SUBSCRIPTION_NAME_PATTERN = re.compile(r"^\s*subscription\b[^{]*\{\s*(\w+)", re.IGNORECASE)
|
||||||
_FORBIDDEN_KEYWORDS = re.compile(r"\b(mutation|query)\b", re.IGNORECASE)
|
_FORBIDDEN_KEYWORDS = re.compile(r"\b(mutation|query)\b", re.IGNORECASE)
|
||||||
|
|
||||||
|
|||||||
@@ -32,12 +32,17 @@ _STABLE_CONNECTION_SECONDS = 30
|
|||||||
def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]:
|
def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Cap log content in subscription data to prevent unbounded memory growth.
|
"""Cap log content in subscription data to prevent unbounded memory growth.
|
||||||
|
|
||||||
If the data contains a 'content' field (from log subscriptions) that exceeds
|
Returns a new dict — does NOT mutate the input. If any nested 'content'
|
||||||
size limits, truncate to the most recent _MAX_RESOURCE_DATA_LINES lines.
|
field (from log subscriptions) exceeds the byte limit, truncate it to the
|
||||||
|
most recent _MAX_RESOURCE_DATA_LINES lines.
|
||||||
|
|
||||||
|
Note: single lines larger than _MAX_RESOURCE_DATA_BYTES are not split and
|
||||||
|
will still be stored at full size; only multi-line content is truncated.
|
||||||
"""
|
"""
|
||||||
|
result: dict[str, Any] = {}
|
||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
data[key] = _cap_log_content(value)
|
result[key] = _cap_log_content(value)
|
||||||
elif (
|
elif (
|
||||||
key == "content"
|
key == "content"
|
||||||
and isinstance(value, str)
|
and isinstance(value, str)
|
||||||
@@ -50,8 +55,12 @@ def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]:
|
|||||||
f"[RESOURCE] Capped log content from {len(lines)} to "
|
f"[RESOURCE] Capped log content from {len(lines)} to "
|
||||||
f"{_MAX_RESOURCE_DATA_LINES} lines ({len(value)} -> {len(truncated)} chars)"
|
f"{_MAX_RESOURCE_DATA_LINES} lines ({len(value)} -> {len(truncated)} chars)"
|
||||||
)
|
)
|
||||||
data[key] = truncated
|
result[key] = truncated
|
||||||
return data
|
else:
|
||||||
|
result[key] = value
|
||||||
|
else:
|
||||||
|
result[key] = value
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class SubscriptionManager:
|
class SubscriptionManager:
|
||||||
@@ -355,11 +364,13 @@ class SubscriptionManager:
|
|||||||
if isinstance(payload["data"], dict)
|
if isinstance(payload["data"], dict)
|
||||||
else payload["data"]
|
else payload["data"]
|
||||||
)
|
)
|
||||||
self.resource_data[subscription_name] = SubscriptionData(
|
new_entry = SubscriptionData(
|
||||||
data=capped_data,
|
data=capped_data,
|
||||||
last_updated=datetime.now(UTC),
|
last_updated=datetime.now(UTC),
|
||||||
subscription_type=subscription_name,
|
subscription_type=subscription_name,
|
||||||
)
|
)
|
||||||
|
async with self.subscription_lock:
|
||||||
|
self.resource_data[subscription_name] = new_entry
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[RESOURCE:{subscription_name}] Resource data updated successfully"
|
f"[RESOURCE:{subscription_name}] Resource data updated successfully"
|
||||||
)
|
)
|
||||||
@@ -484,6 +495,16 @@ class SubscriptionManager:
|
|||||||
self.connection_states[subscription_name] = "reconnecting"
|
self.connection_states[subscription_name] = "reconnecting"
|
||||||
await asyncio.sleep(retry_delay)
|
await asyncio.sleep(retry_delay)
|
||||||
|
|
||||||
|
# The while loop exited (via break or max_retries exceeded).
|
||||||
|
# Remove from active_subscriptions so start_subscription() can restart it.
|
||||||
|
async with self.subscription_lock:
|
||||||
|
self.active_subscriptions.pop(subscription_name, None)
|
||||||
|
logger.info(
|
||||||
|
f"[SUBSCRIPTION:{subscription_name}] Subscription loop ended — "
|
||||||
|
f"removed from active_subscriptions. Final state: "
|
||||||
|
f"{self.connection_states.get(subscription_name, 'unknown')}"
|
||||||
|
)
|
||||||
|
|
||||||
async def get_resource_data(self, resource_name: str) -> dict[str, Any] | None:
|
async def get_resource_data(self, resource_name: str) -> dict[str, Any] | None:
|
||||||
"""Get current resource data with enhanced logging."""
|
"""Get current resource data with enhanced logging."""
|
||||||
logger.debug(f"[RESOURCE:{resource_name}] Resource data requested")
|
logger.debug(f"[RESOURCE:{resource_name}] Resource data requested")
|
||||||
|
|||||||
@@ -99,8 +99,7 @@ MUTATIONS: dict[str, str] = {
|
|||||||
""",
|
""",
|
||||||
}
|
}
|
||||||
|
|
||||||
DESTRUCTIVE_ACTIONS = {"remove"}
|
DESTRUCTIVE_ACTIONS = {"remove", "update_all"}
|
||||||
_MUTATION_ACTIONS = {"start", "stop", "restart", "pause", "unpause", "remove", "update"}
|
|
||||||
# NOTE (Code-M-07): "details" and "logs" are listed here because they require a
|
# NOTE (Code-M-07): "details" and "logs" are listed here because they require a
|
||||||
# container_id parameter, but unlike mutations they use fuzzy name matching (not
|
# container_id parameter, but unlike mutations they use fuzzy name matching (not
|
||||||
# strict). This is intentional: read-only queries are safe with fuzzy matching.
|
# strict). This is intentional: read-only queries are safe with fuzzy matching.
|
||||||
|
|||||||
@@ -37,8 +37,8 @@ def _safe_display_url(url: str | None) -> str | None:
|
|||||||
if parsed.port:
|
if parsed.port:
|
||||||
return f"{parsed.scheme}://{host}:{parsed.port}"
|
return f"{parsed.scheme}://{host}:{parsed.port}"
|
||||||
return f"{parsed.scheme}://{host}"
|
return f"{parsed.scheme}://{host}"
|
||||||
except Exception:
|
except ValueError:
|
||||||
# If parsing fails, show nothing rather than leaking the raw URL
|
# urlparse raises ValueError for invalid URLs (e.g. contains control chars)
|
||||||
return "<unparseable>"
|
return "<unparseable>"
|
||||||
|
|
||||||
|
|
||||||
@@ -235,9 +235,9 @@ def _analyze_subscription_status(
|
|||||||
"""Analyze subscription status dict, returning error count and connection issues.
|
"""Analyze subscription status dict, returning error count and connection issues.
|
||||||
|
|
||||||
This is the canonical implementation of subscription status analysis.
|
This is the canonical implementation of subscription status analysis.
|
||||||
TODO: subscriptions/diagnostics.py (lines 168-182) duplicates this logic.
|
TODO: subscriptions/diagnostics.py has a similar status-analysis pattern
|
||||||
That module should be refactored to call this helper once file ownership
|
in diagnose_subscriptions(). That module could import and call this helper
|
||||||
allows cross-agent edits. See Code-H05.
|
directly to avoid divergence. See Code-H05.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
status: Dict of subscription name -> status info from get_subscription_status().
|
status: Dict of subscription name -> status info from get_subscription_status().
|
||||||
|
|||||||
Reference in New Issue
Block a user