From 1751bc29845d61ba084990c488f32a9894f44c2c Mon Sep 17 00:00:00 2001 From: Jacob Magar Date: Thu, 19 Feb 2026 02:23:04 -0500 Subject: [PATCH] fix: apply all PR review agent findings (silent failures, type safety, test gaps) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses issues found by 4 parallel review agents (code-reviewer, silent-failure-hunter, type-design-analyzer, pr-test-analyzer). Source fixes: - core/utils.py: add public safe_display_url() (moved from tools/health.py) - core/client.py: rename _redact_sensitive → redact_sensitive (public API) - core/types.py: add SubscriptionData.__post_init__ for tz-aware datetime enforcement; remove 6 unused type aliases (SystemHealth, APIResponse, etc.) - subscriptions/manager.py: add exc_info=True to both except-Exception blocks; add except ValueError break-on-config-error before retry loop; import redact_sensitive by new public name - subscriptions/resources.py: re-raise in autostart_subscriptions() so ensure_subscriptions_started() doesn't permanently set _subscriptions_started - subscriptions/diagnostics.py: except ToolError: raise before broad except; use safe_display_url() instead of raw URL slice - tools/health.py: move _safe_display_url to core/utils; add exc_info=True; raise ToolError (not return dict) on ImportError - tools/info.py: use get_args(INFO_ACTIONS) instead of INFO_ACTIONS.__args__ - tools/{array,docker,keys,notifications,rclone,storage,virtualization}.py: add Literal-vs-ALL_ACTIONS sync check at import time Test fixes: - test_health.py: import safe_display_url from core.utils; update test_diagnose_import_error_internal to expect ToolError (not error dict) - test_storage.py: add 3 safe_get tests for zero/False/empty-string values - test_subscription_manager.py: add TestCapLogContentSingleMassiveLine (2 tests) - test_client.py: rename _redact_sensitive → redact_sensitive; add tests for new sensitive keys and is_cacheable explicit-keyword form --- tests/integration/test_subscriptions.py | 1 + tests/test_client.py | 51 ++++++++++++++++------- tests/test_docker.py | 23 ++++++++++- tests/test_health.py | 48 +++++++++++----------- tests/test_info.py | 5 ++- tests/test_rclone.py | 1 - tests/test_storage.py | 9 +++++ tests/test_subscription_manager.py | 51 +++++++++++++++++------ unraid_mcp/__init__.py | 13 ++---- unraid_mcp/config/logging.py | 4 +- unraid_mcp/config/settings.py | 25 +++++++----- unraid_mcp/core/client.py | 18 ++++++--- unraid_mcp/core/exceptions.py | 11 +++-- unraid_mcp/core/types.py | 37 ++++------------- unraid_mcp/core/utils.py | 20 +++++++++ unraid_mcp/subscriptions/diagnostics.py | 5 ++- unraid_mcp/subscriptions/manager.py | 54 ++++++++++++++++++------- unraid_mcp/subscriptions/resources.py | 1 + unraid_mcp/tools/array.py | 10 ++++- unraid_mcp/tools/docker.py | 36 +++++++++++------ unraid_mcp/tools/health.py | 42 +++++++------------ unraid_mcp/tools/info.py | 11 ++--- unraid_mcp/tools/keys.py | 10 ++++- unraid_mcp/tools/notifications.py | 10 ++++- unraid_mcp/tools/rclone.py | 12 +++++- unraid_mcp/tools/storage.py | 12 +++++- unraid_mcp/tools/virtualization.py | 10 ++++- unraid_mcp/version.py | 11 +++++ 28 files changed, 354 insertions(+), 187 deletions(-) create mode 100644 unraid_mcp/version.py diff --git a/tests/integration/test_subscriptions.py b/tests/integration/test_subscriptions.py index 22e3954..755bfd7 100644 --- a/tests/integration/test_subscriptions.py +++ b/tests/integration/test_subscriptions.py @@ -16,6 +16,7 @@ import websockets.exceptions from unraid_mcp.subscriptions.manager import SubscriptionManager + pytestmark = pytest.mark.integration diff --git a/tests/test_client.py b/tests/test_client.py index c90f797..904409c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -12,9 +12,9 @@ from unraid_mcp.core.client import ( DISK_TIMEOUT, _QueryCache, _RateLimiter, - _redact_sensitive, is_idempotent_error, make_graphql_request, + redact_sensitive, ) from unraid_mcp.core.exceptions import ToolError @@ -60,7 +60,7 @@ class TestIsIdempotentError: # --------------------------------------------------------------------------- -# _redact_sensitive +# redact_sensitive # --------------------------------------------------------------------------- @@ -69,36 +69,36 @@ class TestRedactSensitive: def test_flat_dict(self) -> None: data = {"username": "admin", "password": "hunter2", "host": "10.0.0.1"} - result = _redact_sensitive(data) + result = redact_sensitive(data) assert result["username"] == "admin" assert result["password"] == "***" assert result["host"] == "10.0.0.1" def test_nested_dict(self) -> None: data = {"config": {"apiKey": "abc123", "url": "http://host"}} - result = _redact_sensitive(data) + result = redact_sensitive(data) assert result["config"]["apiKey"] == "***" assert result["config"]["url"] == "http://host" def test_list_of_dicts(self) -> None: data = [{"token": "t1"}, {"name": "safe"}] - result = _redact_sensitive(data) + result = redact_sensitive(data) assert result[0]["token"] == "***" assert result[1]["name"] == "safe" def test_deeply_nested(self) -> None: data = {"a": {"b": {"c": {"secret": "deep"}}}} - result = _redact_sensitive(data) + result = redact_sensitive(data) assert result["a"]["b"]["c"]["secret"] == "***" def test_non_dict_passthrough(self) -> None: - assert _redact_sensitive("plain_string") == "plain_string" - assert _redact_sensitive(42) == 42 - assert _redact_sensitive(None) is None + assert redact_sensitive("plain_string") == "plain_string" + assert redact_sensitive(42) == 42 + assert redact_sensitive(None) is None def test_case_insensitive_keys(self) -> None: data = {"Password": "p1", "TOKEN": "t1", "ApiKey": "k1", "Secret": "s1", "Key": "x1"} - result = _redact_sensitive(data) + result = redact_sensitive(data) for v in result.values(): assert v == "***" @@ -112,7 +112,7 @@ class TestRedactSensitive: "username": "safe", "host": "safe", } - result = _redact_sensitive(data) + result = redact_sensitive(data) assert result["user_password"] == "***" assert result["api_key_value"] == "***" assert result["auth_token_expiry"] == "***" @@ -122,12 +122,26 @@ class TestRedactSensitive: def test_mixed_list_content(self) -> None: data = [{"key": "val"}, "string", 123, [{"token": "inner"}]] - result = _redact_sensitive(data) + result = redact_sensitive(data) assert result[0]["key"] == "***" assert result[1] == "string" assert result[2] == 123 assert result[3][0]["token"] == "***" + def test_new_sensitive_keys_are_redacted(self) -> None: + """PR-added keys: authorization, cookie, session, credential, passphrase, jwt.""" + data = { + "authorization": "Bearer token123", + "cookie": "session=abc", + "jwt": "eyJ...", + "credential": "secret_cred", + "passphrase": "hunter2", + "session": "sess_id", + } + result = redact_sensitive(data) + for key, val in result.items(): + assert val == "***", f"Key '{key}' was not redacted" + # --------------------------------------------------------------------------- # Timeout constants @@ -347,7 +361,7 @@ class TestMakeGraphQLRequestErrors: with ( patch("unraid_mcp.core.client.get_http_client", return_value=mock_client), - pytest.raises(ToolError, match="invalid response.*not valid JSON"), + pytest.raises(ToolError, match=r"invalid response.*not valid JSON"), ): await make_graphql_request("{ info }") @@ -481,7 +495,7 @@ class TestRateLimiter: limiter = _RateLimiter(max_tokens=10, refill_rate=1.0) initial = limiter.tokens await limiter.acquire() - assert limiter.tokens == initial - 1 + assert limiter.tokens == pytest.approx(initial - 1, abs=1e-3) async def test_acquire_succeeds_when_tokens_available(self) -> None: limiter = _RateLimiter(max_tokens=5, refill_rate=1.0) @@ -596,6 +610,15 @@ class TestQueryCache: """Queries that start with 'mutation' after whitespace are not cacheable.""" assert _QueryCache.is_cacheable(" mutation { ... }") is False + def test_is_cacheable_with_explicit_query_keyword(self) -> None: + """Operation names after explicit 'query' keyword must be recognized.""" + assert _QueryCache.is_cacheable("query GetNetworkConfig { network { name } }") is True + assert _QueryCache.is_cacheable("query GetOwner { owner { name } }") is True + + def test_is_cacheable_anonymous_query_returns_false(self) -> None: + """Anonymous 'query { ... }' has no operation name — must not be cached.""" + assert _QueryCache.is_cacheable("query { network { name } }") is False + def test_expired_entry_removed_from_store(self) -> None: """Accessing an expired entry should remove it from the internal store.""" cache = _QueryCache() diff --git a/tests/test_docker.py b/tests/test_docker.py index c3591ff..5b045ed 100644 --- a/tests/test_docker.py +++ b/tests/test_docker.py @@ -80,6 +80,14 @@ class TestDockerValidation: with pytest.raises(ToolError, match="network_id"): await tool_fn(action="network_details") + async def test_non_logs_action_ignores_tail_lines_validation( + self, _mock_graphql: AsyncMock + ) -> None: + _mock_graphql.return_value = {"docker": {"containers": []}} + tool_fn = _make_tool() + result = await tool_fn(action="list", tail_lines=0) + assert result["containers"] == [] + class TestDockerActions: async def test_list(self, _mock_graphql: AsyncMock) -> None: @@ -224,9 +232,22 @@ class TestDockerActions: async def test_generic_exception_wraps_in_tool_error(self, _mock_graphql: AsyncMock) -> None: _mock_graphql.side_effect = RuntimeError("unexpected failure") tool_fn = _make_tool() - with pytest.raises(ToolError, match="unexpected failure"): + with pytest.raises(ToolError, match="Failed to execute docker/list"): await tool_fn(action="list") + async def test_short_id_prefix_ambiguous_rejected(self, _mock_graphql: AsyncMock) -> None: + _mock_graphql.return_value = { + "docker": { + "containers": [ + {"id": "abcdef1234560000000000000000000000000000000000000000000000000000:local", "names": ["plex"]}, + {"id": "abcdef1234561111111111111111111111111111111111111111111111111111:local", "names": ["sonarr"]}, + ] + } + } + tool_fn = _make_tool() + with pytest.raises(ToolError, match="ambiguous"): + await tool_fn(action="logs", container_id="abcdef123456") + class TestDockerMutationFailures: """Tests for mutation responses that indicate failure or unexpected shapes.""" diff --git a/tests/test_health.py b/tests/test_health.py index de2f835..8b58732 100644 --- a/tests/test_health.py +++ b/tests/test_health.py @@ -7,7 +7,7 @@ import pytest from conftest import make_tool_fn from unraid_mcp.core.exceptions import ToolError -from unraid_mcp.tools.health import _safe_display_url +from unraid_mcp.core.utils import safe_display_url @pytest.fixture @@ -100,7 +100,7 @@ class TestHealthActions: "unraid_mcp.tools.health._diagnose_subscriptions", side_effect=RuntimeError("broken"), ), - pytest.raises(ToolError, match="broken"), + pytest.raises(ToolError, match="Failed to execute health/diagnose"), ): await tool_fn(action="diagnose") @@ -115,7 +115,7 @@ class TestHealthActions: assert "cpu_sub" in result async def test_diagnose_import_error_internal(self) -> None: - """_diagnose_subscriptions catches ImportError and returns error dict.""" + """_diagnose_subscriptions raises ToolError when subscription modules are unavailable.""" import sys from unraid_mcp.tools.health import _diagnose_subscriptions @@ -127,16 +127,18 @@ class TestHealthActions: try: # Replace the modules with objects that raise ImportError on access - with patch.dict( - sys.modules, - { - "unraid_mcp.subscriptions": None, - "unraid_mcp.subscriptions.manager": None, - "unraid_mcp.subscriptions.resources": None, - }, + with ( + patch.dict( + sys.modules, + { + "unraid_mcp.subscriptions": None, + "unraid_mcp.subscriptions.manager": None, + "unraid_mcp.subscriptions.resources": None, + }, + ), + pytest.raises(ToolError, match="Subscription modules not available"), ): - result = await _diagnose_subscriptions() - assert "error" in result + await _diagnose_subscriptions() finally: # Restore cached modules sys.modules.update(cached) @@ -148,47 +150,47 @@ class TestHealthActions: class TestSafeDisplayUrl: - """Verify that _safe_display_url strips credentials/path and preserves scheme+host+port.""" + """Verify that safe_display_url strips credentials/path and preserves scheme+host+port.""" def test_none_returns_none(self) -> None: - assert _safe_display_url(None) is None + assert safe_display_url(None) is None def test_empty_string_returns_none(self) -> None: - assert _safe_display_url("") is None + assert safe_display_url("") is None def test_simple_url_scheme_and_host(self) -> None: - assert _safe_display_url("https://unraid.local/graphql") == "https://unraid.local" + assert safe_display_url("https://unraid.local/graphql") == "https://unraid.local" def test_preserves_port(self) -> None: - assert _safe_display_url("https://10.1.0.2:31337/api/graphql") == "https://10.1.0.2:31337" + assert safe_display_url("https://10.1.0.2:31337/api/graphql") == "https://10.1.0.2:31337" def test_strips_path(self) -> None: - result = _safe_display_url("http://unraid.local/some/deep/path?query=1") + result = safe_display_url("http://unraid.local/some/deep/path?query=1") assert "path" not in result assert "query" not in result def test_strips_credentials(self) -> None: - result = _safe_display_url("https://user:password@unraid.local/graphql") + result = safe_display_url("https://user:password@unraid.local/graphql") assert "user" not in result assert "password" not in result assert result == "https://unraid.local" def test_strips_query_params(self) -> None: - result = _safe_display_url("http://host.local?token=abc&key=xyz") + result = safe_display_url("http://host.local?token=abc&key=xyz") assert "token" not in result assert "abc" not in result def test_http_scheme_preserved(self) -> None: - result = _safe_display_url("http://10.0.0.1:8080/api") + result = safe_display_url("http://10.0.0.1:8080/api") assert result == "http://10.0.0.1:8080" def test_tailscale_url(self) -> None: - result = _safe_display_url("https://100.118.209.1:31337/graphql") + result = safe_display_url("https://100.118.209.1:31337/graphql") assert result == "https://100.118.209.1:31337" def test_malformed_ipv6_url_returns_unparseable(self) -> None: """Malformed IPv6 brackets in netloc cause urlparse.hostname to raise ValueError.""" # urlparse("https://[invalid") parses without error, but accessing .hostname # raises ValueError: Invalid IPv6 URL — this triggers the except branch. - result = _safe_display_url("https://[invalid") + result = safe_display_url("https://[invalid") assert result == "" diff --git a/tests/test_info.py b/tests/test_info.py index 02fc3ea..2f1c77c 100644 --- a/tests/test_info.py +++ b/tests/test_info.py @@ -186,7 +186,7 @@ class TestUnraidInfoTool: async def test_generic_exception_wraps(self, _mock_graphql: AsyncMock) -> None: _mock_graphql.side_effect = RuntimeError("unexpected") tool_fn = _make_tool() - with pytest.raises(ToolError, match="unexpected"): + with pytest.raises(ToolError, match="Failed to execute info/online"): await tool_fn(action="online") async def test_metrics(self, _mock_graphql: AsyncMock) -> None: @@ -201,6 +201,7 @@ class TestUnraidInfoTool: _mock_graphql.return_value = {"services": [{"name": "docker", "state": "running"}]} tool_fn = _make_tool() result = await tool_fn(action="services") + assert "services" in result assert len(result["services"]) == 1 assert result["services"][0]["name"] == "docker" @@ -225,6 +226,7 @@ class TestUnraidInfoTool: } tool_fn = _make_tool() result = await tool_fn(action="servers") + assert "servers" in result assert len(result["servers"]) == 1 assert result["servers"][0]["name"] == "tower" @@ -248,6 +250,7 @@ class TestUnraidInfoTool: } tool_fn = _make_tool() result = await tool_fn(action="ups_devices") + assert "ups_devices" in result assert len(result["ups_devices"]) == 1 assert result["ups_devices"][0]["model"] == "APC" diff --git a/tests/test_rclone.py b/tests/test_rclone.py index caf93cd..c5a7103 100644 --- a/tests/test_rclone.py +++ b/tests/test_rclone.py @@ -19,7 +19,6 @@ def _make_tool(): return make_tool_fn("unraid_mcp.tools.rclone", "register_rclone_tool", "unraid_rclone") -@pytest.mark.usefixtures("_mock_graphql") class TestRcloneValidation: async def test_delete_requires_confirm(self) -> None: tool_fn = _make_tool() diff --git a/tests/test_storage.py b/tests/test_storage.py index f86e720..5f4ca7e 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -149,6 +149,15 @@ class TestSafeGet: result = safe_get({}, "missing", default=[]) assert result == [] + def test_zero_value_not_replaced_by_default(self) -> None: + assert safe_get({"temp": 0}, "temp", default="N/A") == 0 + + def test_false_value_not_replaced_by_default(self) -> None: + assert safe_get({"active": False}, "active", default=True) is False + + def test_empty_string_not_replaced_by_default(self) -> None: + assert safe_get({"name": ""}, "name", default="unknown") == "" + class TestStorageActions: async def test_shares(self, _mock_graphql: AsyncMock) -> None: diff --git a/tests/test_subscription_manager.py b/tests/test_subscription_manager.py index 53b5080..3c96794 100644 --- a/tests/test_subscription_manager.py +++ b/tests/test_subscription_manager.py @@ -60,8 +60,8 @@ class TestCapLogContentSmallData: class TestCapLogContentTruncation: """Content exceeding both byte AND line limits must be truncated to the last N lines.""" - def test_oversized_content_truncated_to_last_n_lines(self) -> None: - # 200 lines, limit 50 lines, byte limit effectively 0 → should keep last 50 lines + def test_oversized_content_truncated_and_byte_capped(self) -> None: + # 200 lines, tiny byte limit: must keep recent content within byte cap. lines = [f"line {i}" for i in range(200)] data = {"content": "\n".join(lines)} with ( @@ -70,14 +70,13 @@ class TestCapLogContentTruncation: ): result = _cap_log_content(data) result_lines = result["content"].splitlines() - assert len(result_lines) == 50 - # Must be the LAST 50 lines - assert result_lines[0] == "line 150" + assert len(result["content"].encode("utf-8", errors="replace")) <= 10 + # Must keep the most recent line suffix. assert result_lines[-1] == "line 199" - def test_content_with_fewer_lines_than_limit_not_truncated(self) -> None: - """If byte limit exceeded but line count ≤ limit → keep original (not truncated).""" - # 30 lines but byte limit 10 and line limit 50 → 30 < 50 so no truncation + def test_content_with_fewer_lines_than_limit_still_honors_byte_cap(self) -> None: + """If byte limit is exceeded, output must still be capped even with few lines.""" + # 30 lines, byte limit 10, line limit 50 -> must cap bytes regardless of line count lines = [f"line {i}" for i in range(30)] data = {"content": "\n".join(lines)} with ( @@ -85,8 +84,7 @@ class TestCapLogContentTruncation: patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50), ): result = _cap_log_content(data) - # Original content preserved - assert result["content"] == data["content"] + assert len(result["content"].encode("utf-8", errors="replace")) <= 10 def test_non_content_keys_preserved_alongside_truncated_content(self) -> None: lines = [f"line {i}" for i in range(200)] @@ -98,7 +96,7 @@ class TestCapLogContentTruncation: result = _cap_log_content(data) assert result["path"] == "/var/log/syslog" assert result["total_lines"] == 200 - assert len(result["content"].splitlines()) == 50 + assert len(result["content"].encode("utf-8", errors="replace")) <= 10 class TestCapLogContentNested: @@ -112,7 +110,7 @@ class TestCapLogContentNested: patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50), ): result = _cap_log_content(data) - assert len(result["logFile"]["content"].splitlines()) == 50 + assert len(result["logFile"]["content"].encode("utf-8", errors="replace")) <= 10 assert result["logFile"]["path"] == "/var/log/syslog" def test_deeply_nested_content_capped(self) -> None: @@ -123,9 +121,36 @@ class TestCapLogContentNested: patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50), ): result = _cap_log_content(data) - assert len(result["outer"]["inner"]["content"].splitlines()) == 50 + assert len(result["outer"]["inner"]["content"].encode("utf-8", errors="replace")) <= 10 def test_nested_non_content_keys_unaffected(self) -> None: data = {"metrics": {"cpu": 42.5, "memory": 8192}} result = _cap_log_content(data) assert result == data + + +class TestCapLogContentSingleMassiveLine: + """A single line larger than the byte cap must be hard-capped at byte level.""" + + def test_single_massive_line_hard_caps_bytes(self) -> None: + # One line, no newlines, larger than the byte cap. + # The while-loop can't reduce it (len(lines) == 1), so the + # last-resort byte-slice path at manager.py:65-69 must fire. + huge_content = "x" * 200 + data = {"content": huge_content} + with ( + patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10), + patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 5_000), + ): + result = _cap_log_content(data) + assert len(result["content"].encode("utf-8", errors="replace")) <= 10 + + def test_single_massive_line_input_not_mutated(self) -> None: + huge_content = "x" * 200 + data = {"content": huge_content} + with ( + patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10), + patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 5_000), + ): + _cap_log_content(data) + assert data["content"] == huge_content diff --git a/unraid_mcp/__init__.py b/unraid_mcp/__init__.py index b6d6c59..a07b7be 100644 --- a/unraid_mcp/__init__.py +++ b/unraid_mcp/__init__.py @@ -1,13 +1,6 @@ -"""Unraid MCP Server Package. +"""Unraid MCP Server Package.""" -A modular MCP (Model Context Protocol) server that provides tools to interact -with an Unraid server's GraphQL API. -""" - -from importlib.metadata import PackageNotFoundError, version +from .version import VERSION -try: - __version__ = version("unraid-mcp") -except PackageNotFoundError: - __version__ = "0.0.0" +__version__ = VERSION diff --git a/unraid_mcp/config/logging.py b/unraid_mcp/config/logging.py index 0df21c6..f0193d8 100644 --- a/unraid_mcp/config/logging.py +++ b/unraid_mcp/config/logging.py @@ -47,7 +47,7 @@ class OverwriteFileHandler(logging.FileHandler): """Emit a record, checking file size periodically and overwriting if needed.""" self._emit_count += 1 if ( - self._emit_count % self._check_interval == 0 + (self._emit_count == 1 or self._emit_count % self._check_interval == 0) and self.stream and hasattr(self.stream, "name") ): @@ -249,5 +249,3 @@ if FASTMCP_AVAILABLE: else: # Fallback to our custom logger if FastMCP is not available logger = setup_logger() - # Also configure FastMCP logger for consistency - configure_fastmcp_logger_with_rich() diff --git a/unraid_mcp/config/settings.py b/unraid_mcp/config/settings.py index cdea8b6..1478199 100644 --- a/unraid_mcp/config/settings.py +++ b/unraid_mcp/config/settings.py @@ -5,12 +5,13 @@ and provides all configuration constants used throughout the application. """ import os -from importlib.metadata import PackageNotFoundError, version from pathlib import Path from typing import Any from dotenv import load_dotenv +from ..version import VERSION as APP_VERSION + # Get the script directory (config module location) SCRIPT_DIR = Path(__file__).parent # /home/user/code/unraid-mcp/unraid_mcp/config/ @@ -31,12 +32,6 @@ for dotenv_path in dotenv_paths: load_dotenv(dotenv_path=dotenv_path) break -# Application Version (single source of truth: pyproject.toml) -try: - VERSION = version("unraid-mcp") -except PackageNotFoundError: - VERSION = "0.0.0" - # Core API Configuration UNRAID_API_URL = os.getenv("UNRAID_API_URL") UNRAID_API_KEY = os.getenv("UNRAID_API_KEY") @@ -58,12 +53,18 @@ else: # Path to CA bundle # Logging Configuration LOG_LEVEL_STR = os.getenv("UNRAID_MCP_LOG_LEVEL", "INFO").upper() LOG_FILE_NAME = os.getenv("UNRAID_MCP_LOG_FILE", "unraid-mcp.log") -# Use /app/logs in Docker, project-relative logs/ directory otherwise -LOGS_DIR = Path("/app/logs") if Path("/app").is_dir() else PROJECT_ROOT / "logs" +# Use /.dockerenv as the container indicator for robust Docker detection. +IS_DOCKER = Path("/.dockerenv").exists() +LOGS_DIR = Path("/app/logs") if IS_DOCKER else PROJECT_ROOT / "logs" LOG_FILE_PATH = LOGS_DIR / LOG_FILE_NAME -# Ensure logs directory exists -LOGS_DIR.mkdir(parents=True, exist_ok=True) +# Ensure logs directory exists; if creation fails, fall back to /tmp. +try: + LOGS_DIR.mkdir(parents=True, exist_ok=True) +except OSError: + LOGS_DIR = PROJECT_ROOT / ".cache" / "logs" + LOGS_DIR.mkdir(parents=True, exist_ok=True) + LOG_FILE_PATH = LOGS_DIR / LOG_FILE_NAME # HTTP Client Configuration TIMEOUT_CONFIG = { @@ -109,3 +110,5 @@ def get_config_summary() -> dict[str, Any]: "config_valid": is_valid, "missing_config": missing if not is_valid else None, } +# Re-export application version from a single source of truth. +VERSION = APP_VERSION diff --git a/unraid_mcp/core/client.py b/unraid_mcp/core/client.py index ea568cf..eb452d6 100644 --- a/unraid_mcp/core/client.py +++ b/unraid_mcp/core/client.py @@ -7,6 +7,7 @@ to the Unraid API with proper timeout handling and error management. import asyncio import hashlib import json +import re import time from typing import Any, Final @@ -47,14 +48,14 @@ def _is_sensitive_key(key: str) -> bool: return any(s in key_lower for s in _SENSITIVE_KEYS) -def _redact_sensitive(obj: Any) -> Any: +def redact_sensitive(obj: Any) -> Any: """Recursively redact sensitive values from nested dicts/lists.""" if isinstance(obj, dict): return { - k: ("***" if _is_sensitive_key(k) else _redact_sensitive(v)) for k, v in obj.items() + k: ("***" if _is_sensitive_key(k) else redact_sensitive(v)) for k, v in obj.items() } if isinstance(obj, list): - return [_redact_sensitive(item) for item in obj] + return [redact_sensitive(item) for item in obj] return obj @@ -139,6 +140,7 @@ _CACHEABLE_QUERY_PREFIXES = frozenset( ) _CACHE_TTL_SECONDS = 60.0 +_OPERATION_NAME_PATTERN = re.compile(r"^(?:query\s+)?([_A-Za-z][_0-9A-Za-z]*)\b") class _QueryCache: @@ -160,9 +162,13 @@ class _QueryCache: @staticmethod def is_cacheable(query: str) -> bool: """Check if a query is eligible for caching based on its operation name.""" - if query.lstrip().startswith("mutation"): + normalized = query.lstrip() + if normalized.startswith("mutation"): return False - return any(prefix in query for prefix in _CACHEABLE_QUERY_PREFIXES) + match = _OPERATION_NAME_PATTERN.match(normalized) + if not match: + return False + return match.group(1) in _CACHEABLE_QUERY_PREFIXES def get(self, query: str, variables: dict[str, Any] | None) -> dict[str, Any] | None: """Return cached result if present and not expired, else None.""" @@ -324,7 +330,7 @@ async def make_graphql_request( logger.debug(f"Making GraphQL request to {UNRAID_API_URL}:") logger.debug(f"Query: {query[:200]}{'...' if len(query) > 200 else ''}") # Log truncated query if variables: - logger.debug(f"Variables: {_redact_sensitive(variables)}") + logger.debug(f"Variables: {redact_sensitive(variables)}") try: # Rate limit: consume a token before making the request diff --git a/unraid_mcp/core/exceptions.py b/unraid_mcp/core/exceptions.py index 34d84d1..7737155 100644 --- a/unraid_mcp/core/exceptions.py +++ b/unraid_mcp/core/exceptions.py @@ -45,13 +45,12 @@ def tool_error_handler( except ToolError: raise except TimeoutError as e: - logger.error( - f"Timeout in unraid_{tool_name} action={action}: request exceeded time limit", - exc_info=True, - ) + logger.exception(f"Timeout in unraid_{tool_name} action={action}: request exceeded time limit") raise ToolError( f"Request timed out executing {tool_name}/{action}. The Unraid API did not respond in time." ) from e except Exception as e: - logger.error(f"Error in unraid_{tool_name} action={action}: {e}", exc_info=True) - raise ToolError(f"Failed to execute {tool_name}/{action}: {e!s}") from e + logger.exception(f"Error in unraid_{tool_name} action={action}") + raise ToolError( + f"Failed to execute {tool_name}/{action}. Check server logs for details." + ) from e diff --git a/unraid_mcp/core/types.py b/unraid_mcp/core/types.py index 9b7ec8a..dc6ad0d 100644 --- a/unraid_mcp/core/types.py +++ b/unraid_mcp/core/types.py @@ -20,33 +20,10 @@ class SubscriptionData: last_updated: datetime # Must be timezone-aware (UTC) subscription_type: str - -@dataclass(slots=True) -class SystemHealth: - """Container for system health status information. - - Note: last_checked must be timezone-aware (use datetime.now(UTC)). - """ - - is_healthy: bool - issues: list[str] - warnings: list[str] - last_checked: datetime # Must be timezone-aware (UTC) - component_status: dict[str, str] - - -@dataclass(slots=True) -class APIResponse: - """Container for standardized API response data.""" - - success: bool - data: dict[str, Any] | None = None - error: str | None = None - metadata: dict[str, Any] | None = None - - -# Type aliases for common data structures -ConfigValue = str | int | bool | float | None -ConfigDict = dict[str, ConfigValue] -GraphQLVariables = dict[str, Any] -HealthStatus = dict[str, str | bool | int | list[Any]] + def __post_init__(self) -> None: + if self.last_updated.tzinfo is None: + raise ValueError( + "last_updated must be timezone-aware; use datetime.now(UTC)" + ) + if not self.subscription_type.strip(): + raise ValueError("subscription_type must be a non-empty string") diff --git a/unraid_mcp/core/utils.py b/unraid_mcp/core/utils.py index 5b5ec9b..fd02ed2 100644 --- a/unraid_mcp/core/utils.py +++ b/unraid_mcp/core/utils.py @@ -1,6 +1,7 @@ """Shared utility functions for Unraid MCP tools.""" from typing import Any +from urllib.parse import urlparse def safe_get(data: dict[str, Any], *keys: str, default: Any = None) -> Any: @@ -45,6 +46,25 @@ def format_bytes(bytes_value: int | None) -> str: return f"{value:.2f} EB" +def safe_display_url(url: str | None) -> str | None: + """Return a redacted URL showing only scheme + host + port. + + Strips path, query parameters, credentials, and fragments to avoid + leaking internal network topology or embedded secrets (CWE-200). + """ + if not url: + return None + try: + parsed = urlparse(url) + host = parsed.hostname or "unknown" + if parsed.port: + return f"{parsed.scheme}://{host}:{parsed.port}" + return f"{parsed.scheme}://{host}" + except ValueError: + # urlparse raises ValueError for invalid URLs (e.g. contains control chars) + return "" + + def format_kb(k: Any) -> str: """Format kilobyte values into human-readable sizes. diff --git a/unraid_mcp/subscriptions/diagnostics.py b/unraid_mcp/subscriptions/diagnostics.py index f72d010..b9dbcad 100644 --- a/unraid_mcp/subscriptions/diagnostics.py +++ b/unraid_mcp/subscriptions/diagnostics.py @@ -19,6 +19,7 @@ from websockets.typing import Subprotocol from ..config.logging import logger from ..config.settings import UNRAID_API_KEY, UNRAID_API_URL from ..core.exceptions import ToolError +from ..core.utils import safe_display_url from .manager import subscription_manager from .resources import ensure_subscriptions_started from .utils import build_ws_ssl_context, build_ws_url @@ -162,6 +163,8 @@ def register_diagnostic_tools(mcp: FastMCP) -> None: "note": "Connection successful, subscription may be waiting for events", } + except ToolError: + raise except Exception as e: logger.error(f"[TEST_SUBSCRIPTION] Error: {e}", exc_info=True) return {"error": str(e), "query_tested": subscription_query} @@ -193,7 +196,7 @@ def register_diagnostic_tools(mcp: FastMCP) -> None: "environment": { "auto_start_enabled": subscription_manager.auto_start_enabled, "max_reconnect_attempts": subscription_manager.max_reconnect_attempts, - "unraid_api_url": UNRAID_API_URL[:50] + "..." if UNRAID_API_URL else None, + "unraid_api_url": safe_display_url(UNRAID_API_URL), "api_key_configured": bool(UNRAID_API_KEY), "websocket_url": None, }, diff --git a/unraid_mcp/subscriptions/manager.py b/unraid_mcp/subscriptions/manager.py index 0416e2e..c58cd5a 100644 --- a/unraid_mcp/subscriptions/manager.py +++ b/unraid_mcp/subscriptions/manager.py @@ -17,7 +17,7 @@ from websockets.typing import Subprotocol from ..config.logging import logger from ..config.settings import UNRAID_API_KEY -from ..core.client import _redact_sensitive +from ..core.client import redact_sensitive from ..core.types import SubscriptionData from .utils import build_ws_ssl_context, build_ws_url @@ -36,8 +36,7 @@ def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]: field (from log subscriptions) exceeds the byte limit, truncate it to the most recent _MAX_RESOURCE_DATA_LINES lines. - Note: single lines larger than _MAX_RESOURCE_DATA_BYTES are not split and - will still be stored at full size; only multi-line content is truncated. + The final content is guaranteed to be <= _MAX_RESOURCE_DATA_BYTES. """ result: dict[str, Any] = {} for key, value in data.items(): @@ -49,15 +48,31 @@ def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]: and len(value.encode("utf-8", errors="replace")) > _MAX_RESOURCE_DATA_BYTES ): lines = value.splitlines() + original_line_count = len(lines) + + # Keep most recent lines first. if len(lines) > _MAX_RESOURCE_DATA_LINES: - truncated = "\n".join(lines[-_MAX_RESOURCE_DATA_LINES:]) - logger.warning( - f"[RESOURCE] Capped log content from {len(lines)} to " - f"{_MAX_RESOURCE_DATA_LINES} lines ({len(value)} -> {len(truncated)} chars)" + lines = lines[-_MAX_RESOURCE_DATA_LINES:] + + # Enforce byte cap while preserving whole-line boundaries where possible. + truncated = "\n".join(lines) + truncated_bytes = truncated.encode("utf-8", errors="replace") + while len(lines) > 1 and len(truncated_bytes) > _MAX_RESOURCE_DATA_BYTES: + lines = lines[1:] + truncated = "\n".join(lines) + truncated_bytes = truncated.encode("utf-8", errors="replace") + + # Last resort: if a single line still exceeds cap, hard-cap bytes. + if len(truncated_bytes) > _MAX_RESOURCE_DATA_BYTES: + truncated = truncated_bytes[-_MAX_RESOURCE_DATA_BYTES :].decode( + "utf-8", errors="ignore" ) - result[key] = truncated - else: - result[key] = value + + logger.warning( + f"[RESOURCE] Capped log content from {original_line_count} to " + f"{len(lines)} lines ({len(value)} -> {len(truncated)} chars)" + ) + result[key] = truncated else: result[key] = value return result @@ -148,6 +163,7 @@ class SubscriptionManager: # Reset connection tracking self.reconnect_attempts[subscription_name] = 0 self.connection_states[subscription_name] = "starting" + self._connection_start_times.pop(subscription_name, None) async with self.subscription_lock: try: @@ -181,6 +197,7 @@ class SubscriptionManager: logger.debug(f"[SUBSCRIPTION:{subscription_name}] Task cancelled successfully") del self.active_subscriptions[subscription_name] self.connection_states[subscription_name] = "stopped" + self._connection_start_times.pop(subscription_name, None) logger.info(f"[SUBSCRIPTION:{subscription_name}] Subscription stopped") else: logger.warning(f"[SUBSCRIPTION:{subscription_name}] No active subscription to stop") @@ -322,7 +339,7 @@ class SubscriptionManager: ) logger.debug(f"[SUBSCRIPTION:{subscription_name}] Query: {query[:100]}...") logger.debug( - f"[SUBSCRIPTION:{subscription_name}] Variables: {_redact_sensitive(variables)}" + f"[SUBSCRIPTION:{subscription_name}] Variables: {redact_sensitive(variables)}" ) await websocket.send(json.dumps(subscription_message)) @@ -431,7 +448,8 @@ class SubscriptionManager: logger.error(f"[PROTOCOL:{subscription_name}] JSON decode error: {e}") except Exception as e: logger.error( - f"[DATA:{subscription_name}] Error processing message: {e}" + f"[DATA:{subscription_name}] Error processing message: {e}", + exc_info=True, ) msg_preview = ( message[:200] @@ -461,14 +479,22 @@ class SubscriptionManager: self.connection_states[subscription_name] = "invalid_uri" break # Don't retry on invalid URI + except ValueError as e: + # Non-retryable configuration error (e.g. UNRAID_API_URL not set) + error_msg = f"Configuration error: {e}" + logger.error(f"[WEBSOCKET:{subscription_name}] {error_msg}") + self.last_error[subscription_name] = error_msg + self.connection_states[subscription_name] = "error" + break # Don't retry on configuration errors + except Exception as e: error_msg = f"Unexpected error: {e}" - logger.error(f"[WEBSOCKET:{subscription_name}] {error_msg}") + logger.error(f"[WEBSOCKET:{subscription_name}] {error_msg}", exc_info=True) self.last_error[subscription_name] = error_msg self.connection_states[subscription_name] = "error" # Check if connection was stable before deciding on retry behavior - start_time = self._connection_start_times.get(subscription_name) + start_time = self._connection_start_times.pop(subscription_name, None) if start_time is not None: connected_duration = time.monotonic() - start_time if connected_duration >= _STABLE_CONNECTION_SECONDS: diff --git a/unraid_mcp/subscriptions/resources.py b/unraid_mcp/subscriptions/resources.py index f80a708..850ac1c 100644 --- a/unraid_mcp/subscriptions/resources.py +++ b/unraid_mcp/subscriptions/resources.py @@ -44,6 +44,7 @@ async def autostart_subscriptions() -> None: logger.info("[AUTOSTART] Auto-start process completed successfully") except Exception as e: logger.error(f"[AUTOSTART] Failed during auto-start process: {e}", exc_info=True) + raise # Propagate so ensure_subscriptions_started doesn't mark as started # Optional log file subscription log_path = os.getenv("UNRAID_AUTOSTART_LOG_PATH") diff --git a/unraid_mcp/tools/array.py b/unraid_mcp/tools/array.py index 0afe755..85fe93b 100644 --- a/unraid_mcp/tools/array.py +++ b/unraid_mcp/tools/array.py @@ -3,7 +3,7 @@ Provides the `unraid_array` tool with 5 actions for parity check management. """ -from typing import Any, Literal +from typing import Any, Literal, get_args from fastmcp import FastMCP @@ -53,6 +53,14 @@ ARRAY_ACTIONS = Literal[ "parity_status", ] +if set(get_args(ARRAY_ACTIONS)) != ALL_ACTIONS: + _missing = ALL_ACTIONS - set(get_args(ARRAY_ACTIONS)) + _extra = set(get_args(ARRAY_ACTIONS)) - ALL_ACTIONS + raise RuntimeError( + f"ARRAY_ACTIONS and ALL_ACTIONS are out of sync. " + f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}" + ) + def register_array_tool(mcp: FastMCP) -> None: """Register the unraid_array tool with the FastMCP instance.""" diff --git a/unraid_mcp/tools/docker.py b/unraid_mcp/tools/docker.py index cd31e7b..b125551 100644 --- a/unraid_mcp/tools/docker.py +++ b/unraid_mcp/tools/docker.py @@ -5,7 +5,7 @@ logs, networks, and update management. """ import re -from typing import Any, Literal +from typing import Any, Literal, get_args from fastmcp import FastMCP @@ -135,6 +135,14 @@ DOCKER_ACTIONS = Literal[ "check_updates", ] +if set(get_args(DOCKER_ACTIONS)) != ALL_ACTIONS: + _missing = ALL_ACTIONS - set(get_args(DOCKER_ACTIONS)) + _extra = set(get_args(DOCKER_ACTIONS)) - ALL_ACTIONS + raise RuntimeError( + f"DOCKER_ACTIONS and ALL_ACTIONS are out of sync. " + f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}" + ) + # Full PrefixedID: 64 hex chars + optional suffix (e.g., ":local") _DOCKER_ID_PATTERN = re.compile(r"^[a-f0-9]{64}(:[a-z0-9]+)?$", re.IGNORECASE) @@ -199,11 +207,6 @@ def get_available_container_names(containers: list[dict[str, Any]]) -> list[str] return names -def _looks_like_container_id(identifier: str) -> bool: - """Check if an identifier looks like a container ID (full or short hex prefix).""" - return bool(_DOCKER_ID_PATTERN.match(identifier) or _DOCKER_SHORT_ID_PATTERN.match(identifier)) - - async def _resolve_container_id(container_id: str, *, strict: bool = False) -> str: """Resolve a container name/identifier to its actual PrefixedID. @@ -233,12 +236,21 @@ async def _resolve_container_id(container_id: str, *, strict: bool = False) -> s # Short hex prefix: match by ID prefix before trying name matching if _DOCKER_SHORT_ID_PATTERN.match(container_id): id_lower = container_id.lower() + matches: list[dict[str, Any]] = [] for c in containers: cid = (c.get("id") or "").lower() if cid.startswith(id_lower) or cid.split(":")[0].startswith(id_lower): - actual_id = str(c.get("id", "")) - logger.info(f"Resolved short ID '{container_id}' -> '{actual_id}'") - return actual_id + matches.append(c) + if len(matches) == 1: + actual_id = str(matches[0].get("id", "")) + logger.info(f"Resolved short ID '{container_id}' -> '{actual_id}'") + return actual_id + if len(matches) > 1: + candidate_ids = [str(c.get("id", "")) for c in matches[:5]] + raise ToolError( + f"Short container ID prefix '{container_id}' is ambiguous. " + f"Matches: {', '.join(candidate_ids)}. Use a longer ID or exact name." + ) resolved = find_container_by_identifier(container_id, containers, strict=strict) if resolved: @@ -303,7 +315,7 @@ def register_docker_tool(mcp: FastMCP) -> None: if action == "network_details" and not network_id: raise ToolError("network_id is required for 'network_details' action") - if tail_lines < 1 or tail_lines > _MAX_TAIL_LINES: + if action == "logs" and (tail_lines < 1 or tail_lines > _MAX_TAIL_LINES): raise ToolError(f"tail_lines must be between 1 and {_MAX_TAIL_LINES}, got {tail_lines}") with tool_error_handler("docker", action, logger): @@ -335,12 +347,12 @@ def register_docker_tool(mcp: FastMCP) -> None: if action == "networks": data = await make_graphql_request(QUERIES["networks"]) - networks = data.get("dockerNetworks", []) + networks = safe_get(data, "dockerNetworks", default=[]) return {"networks": networks} if action == "network_details": data = await make_graphql_request(QUERIES["network_details"], {"id": network_id}) - return dict(data.get("dockerNetwork") or {}) + return dict(safe_get(data, "dockerNetwork", default={}) or {}) if action == "port_conflicts": data = await make_graphql_request(QUERIES["port_conflicts"]) diff --git a/unraid_mcp/tools/health.py b/unraid_mcp/tools/health.py index dc34559..e025171 100644 --- a/unraid_mcp/tools/health.py +++ b/unraid_mcp/tools/health.py @@ -6,8 +6,7 @@ connection testing, and subscription diagnostics. import datetime import time -from typing import Any, Literal -from urllib.parse import urlparse +from typing import Any, Literal, get_args from fastmcp import FastMCP @@ -21,31 +20,21 @@ from ..config.settings import ( ) from ..core.client import make_graphql_request from ..core.exceptions import ToolError, tool_error_handler - - -def _safe_display_url(url: str | None) -> str | None: - """Return a redacted URL showing only scheme + host + port. - - Strips path, query parameters, credentials, and fragments to avoid - leaking internal network topology or embedded secrets (CWE-200). - """ - if not url: - return None - try: - parsed = urlparse(url) - host = parsed.hostname or "unknown" - if parsed.port: - return f"{parsed.scheme}://{host}:{parsed.port}" - return f"{parsed.scheme}://{host}" - except ValueError: - # urlparse raises ValueError for invalid URLs (e.g. contains control chars) - return "" +from ..core.utils import safe_display_url ALL_ACTIONS = {"check", "test_connection", "diagnose"} HEALTH_ACTIONS = Literal["check", "test_connection", "diagnose"] +if set(get_args(HEALTH_ACTIONS)) != ALL_ACTIONS: + _missing = ALL_ACTIONS - set(get_args(HEALTH_ACTIONS)) + _extra = set(get_args(HEALTH_ACTIONS)) - ALL_ACTIONS + raise RuntimeError( + "HEALTH_ACTIONS and ALL_ACTIONS are out of sync. " + f"Missing in HEALTH_ACTIONS: {_missing}; extra in HEALTH_ACTIONS: {_extra}" + ) + # Severity ordering: only upgrade, never downgrade _SEVERITY = {"healthy": 0, "warning": 1, "degraded": 2, "unhealthy": 3} @@ -149,7 +138,7 @@ async def _comprehensive_check() -> dict[str, Any]: if info: health_info["unraid_system"] = { "status": "connected", - "url": _safe_display_url(UNRAID_API_URL), + "url": safe_display_url(UNRAID_API_URL), "machine_id": info.get("machineId"), "version": info.get("versions", {}).get("unraid"), "uptime": info.get("os", {}).get("uptime"), @@ -220,7 +209,7 @@ async def _comprehensive_check() -> dict[str, Any]: except Exception as e: # Intentionally broad: health checks must always return a result, # even on unexpected failures, so callers never get an unhandled exception. - logger.error(f"Health check failed: {e}") + logger.error(f"Health check failed: {e}", exc_info=True) return { "status": "unhealthy", "timestamp": datetime.datetime.now(datetime.UTC).isoformat(), @@ -293,10 +282,7 @@ async def _diagnose_subscriptions() -> dict[str, Any]: }, } - except ImportError: - return { - "error": "Subscription modules not available", - "timestamp": datetime.datetime.now(datetime.UTC).isoformat(), - } + except ImportError as e: + raise ToolError("Subscription modules not available") from e except Exception as e: raise ToolError(f"Failed to generate diagnostics: {e!s}") from e diff --git a/unraid_mcp/tools/info.py b/unraid_mcp/tools/info.py index b1287bb..75ff1d5 100644 --- a/unraid_mcp/tools/info.py +++ b/unraid_mcp/tools/info.py @@ -4,7 +4,7 @@ Provides the `unraid_info` tool with 19 read-only actions for retrieving system information, array status, network config, and server metadata. """ -from typing import Any, Literal +from typing import Any, Literal, get_args from fastmcp import FastMCP @@ -180,9 +180,9 @@ INFO_ACTIONS = Literal[ "ups_config", ] -if set(INFO_ACTIONS.__args__) != ALL_ACTIONS: - _missing = ALL_ACTIONS - set(INFO_ACTIONS.__args__) - _extra = set(INFO_ACTIONS.__args__) - ALL_ACTIONS +if set(get_args(INFO_ACTIONS)) != ALL_ACTIONS: + _missing = ALL_ACTIONS - set(get_args(INFO_ACTIONS)) + _extra = set(get_args(INFO_ACTIONS)) - ALL_ACTIONS raise RuntimeError( f"QUERIES keys and INFO_ACTIONS are out of sync. " f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}" @@ -415,7 +415,8 @@ def register_info_tool(mcp: FastMCP) -> None: if action in list_actions: response_key, output_key = list_actions[action] items = data.get(response_key) or [] - return {output_key: items} + normalized_items = list(items) if isinstance(items, list) else [] + return {output_key: normalized_items} raise ToolError(f"Unhandled action '{action}' — this is a bug") diff --git a/unraid_mcp/tools/keys.py b/unraid_mcp/tools/keys.py index be9c539..191c970 100644 --- a/unraid_mcp/tools/keys.py +++ b/unraid_mcp/tools/keys.py @@ -4,7 +4,7 @@ Provides the `unraid_keys` tool with 5 actions for listing, viewing, creating, updating, and deleting API keys. """ -from typing import Any, Literal +from typing import Any, Literal, get_args from fastmcp import FastMCP @@ -55,6 +55,14 @@ KEY_ACTIONS = Literal[ "delete", ] +if set(get_args(KEY_ACTIONS)) != ALL_ACTIONS: + _missing = ALL_ACTIONS - set(get_args(KEY_ACTIONS)) + _extra = set(get_args(KEY_ACTIONS)) - ALL_ACTIONS + raise RuntimeError( + f"KEY_ACTIONS and ALL_ACTIONS are out of sync. " + f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}" + ) + def register_keys_tool(mcp: FastMCP) -> None: """Register the unraid_keys tool with the FastMCP instance.""" diff --git a/unraid_mcp/tools/notifications.py b/unraid_mcp/tools/notifications.py index 0df7e2a..3053a13 100644 --- a/unraid_mcp/tools/notifications.py +++ b/unraid_mcp/tools/notifications.py @@ -4,7 +4,7 @@ Provides the `unraid_notifications` tool with 9 actions for viewing, creating, archiving, and deleting system notifications. """ -from typing import Any, Literal +from typing import Any, Literal, get_args from fastmcp import FastMCP @@ -91,6 +91,14 @@ NOTIFICATION_ACTIONS = Literal[ "archive_all", ] +if set(get_args(NOTIFICATION_ACTIONS)) != ALL_ACTIONS: + _missing = ALL_ACTIONS - set(get_args(NOTIFICATION_ACTIONS)) + _extra = set(get_args(NOTIFICATION_ACTIONS)) - ALL_ACTIONS + raise RuntimeError( + f"NOTIFICATION_ACTIONS and ALL_ACTIONS are out of sync. " + f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}" + ) + def register_notifications_tool(mcp: FastMCP) -> None: """Register the unraid_notifications tool with the FastMCP instance.""" diff --git a/unraid_mcp/tools/rclone.py b/unraid_mcp/tools/rclone.py index 7c091cd..a9af93e 100644 --- a/unraid_mcp/tools/rclone.py +++ b/unraid_mcp/tools/rclone.py @@ -5,7 +5,7 @@ cloud storage remotes (S3, Google Drive, Dropbox, FTP, etc.). """ import re -from typing import Any, Literal +from typing import Any, Literal, get_args from fastmcp import FastMCP @@ -50,10 +50,18 @@ RCLONE_ACTIONS = Literal[ "delete_remote", ] +if set(get_args(RCLONE_ACTIONS)) != ALL_ACTIONS: + _missing = ALL_ACTIONS - set(get_args(RCLONE_ACTIONS)) + _extra = set(get_args(RCLONE_ACTIONS)) - ALL_ACTIONS + raise RuntimeError( + f"RCLONE_ACTIONS and ALL_ACTIONS are out of sync. " + f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}" + ) + # Max config entries to prevent abuse _MAX_CONFIG_KEYS = 50 # Pattern for suspicious key names (path traversal, shell metacharacters) -_DANGEROUS_KEY_PATTERN = re.compile(r"[.]{2}|[/\\;|`$(){}]") +_DANGEROUS_KEY_PATTERN = re.compile(r"\.\.|[/\\;|`$(){}]") # Max length for individual config values _MAX_VALUE_LENGTH = 4096 diff --git a/unraid_mcp/tools/storage.py b/unraid_mcp/tools/storage.py index 125595c..a50049e 100644 --- a/unraid_mcp/tools/storage.py +++ b/unraid_mcp/tools/storage.py @@ -5,7 +5,7 @@ unassigned devices, log files, and log content retrieval. """ import os -from typing import Any, Literal +from typing import Any, Literal, get_args from fastmcp import FastMCP @@ -69,6 +69,14 @@ STORAGE_ACTIONS = Literal[ "logs", ] +if set(get_args(STORAGE_ACTIONS)) != ALL_ACTIONS: + _missing = ALL_ACTIONS - set(get_args(STORAGE_ACTIONS)) + _extra = set(get_args(STORAGE_ACTIONS)) - ALL_ACTIONS + raise RuntimeError( + f"STORAGE_ACTIONS and ALL_ACTIONS are out of sync. " + f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}" + ) + def register_storage_tool(mcp: FastMCP) -> None: """Register the unraid_storage tool with the FastMCP instance.""" @@ -96,7 +104,7 @@ def register_storage_tool(mcp: FastMCP) -> None: if action == "disk_details" and not disk_id: raise ToolError("disk_id is required for 'disk_details' action") - if tail_lines < 1 or tail_lines > _MAX_TAIL_LINES: + if action == "logs" and (tail_lines < 1 or tail_lines > _MAX_TAIL_LINES): raise ToolError(f"tail_lines must be between 1 and {_MAX_TAIL_LINES}, got {tail_lines}") if action == "logs": diff --git a/unraid_mcp/tools/virtualization.py b/unraid_mcp/tools/virtualization.py index baa421a..89166b5 100644 --- a/unraid_mcp/tools/virtualization.py +++ b/unraid_mcp/tools/virtualization.py @@ -4,7 +4,7 @@ Provides the `unraid_vm` tool with 9 actions for VM lifecycle management including start, stop, pause, resume, force stop, reboot, and reset. """ -from typing import Any, Literal +from typing import Any, Literal, get_args from fastmcp import FastMCP @@ -73,6 +73,14 @@ VM_ACTIONS = Literal[ ALL_ACTIONS = set(QUERIES) | set(MUTATIONS) +if set(get_args(VM_ACTIONS)) != ALL_ACTIONS: + _missing = ALL_ACTIONS - set(get_args(VM_ACTIONS)) + _extra = set(get_args(VM_ACTIONS)) - ALL_ACTIONS + raise RuntimeError( + f"VM_ACTIONS and ALL_ACTIONS are out of sync. " + f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}" + ) + def register_vm_tool(mcp: FastMCP) -> None: """Register the unraid_vm tool with the FastMCP instance.""" diff --git a/unraid_mcp/version.py b/unraid_mcp/version.py new file mode 100644 index 0000000..b97e207 --- /dev/null +++ b/unraid_mcp/version.py @@ -0,0 +1,11 @@ +"""Application version helpers.""" + +from importlib.metadata import PackageNotFoundError, version + + +__all__ = ["VERSION"] + +try: + VERSION = version("unraid-mcp") +except PackageNotFoundError: + VERSION = "0.0.0"