From a5ed5aab5fe431fb55a4a1e9d297940efa2973d7 Mon Sep 17 00:00:00 2001 From: Jacob Magar Date: Fri, 13 Mar 2026 10:55:54 -0400 Subject: [PATCH] fix: update tests for confirm guard on update_ssh and field-based subscription allow-list --- tests/test_info.py | 34 ++++++++++++++++++++------- tests/test_subscription_validation.py | 32 ++++++++++++------------- 2 files changed, 42 insertions(+), 24 deletions(-) diff --git a/tests/test_info.py b/tests/test_info.py index cc2e910..a2256fb 100644 --- a/tests/test_info.py +++ b/tests/test_info.py @@ -291,37 +291,55 @@ class TestInfoMutations: await tool_fn(action="update_server") async def test_update_server_success(self, _mock_graphql: AsyncMock) -> None: - _mock_graphql.return_value = {"updateServerIdentity": {"id": "s:1", "name": "tootie", "comment": None, "status": "online"}} + _mock_graphql.return_value = { + "updateServerIdentity": { + "id": "s:1", + "name": "tootie", + "comment": None, + "status": "online", + } + } tool_fn = _make_tool() result = await tool_fn(action="update_server", server_name="tootie") assert result["success"] is True assert result["data"]["name"] == "tootie" async def test_update_server_passes_optional_fields(self, _mock_graphql: AsyncMock) -> None: - _mock_graphql.return_value = {"updateServerIdentity": {"id": "s:1", "name": "x", "comment": None, "status": "online"}} + _mock_graphql.return_value = { + "updateServerIdentity": {"id": "s:1", "name": "x", "comment": None, "status": "online"} + } tool_fn = _make_tool() await tool_fn(action="update_server", server_name="x", sys_model="custom") assert _mock_graphql.call_args[0][1]["sysModel"] == "custom" + async def test_update_ssh_requires_confirm(self, _mock_graphql: AsyncMock) -> None: + tool_fn = _make_tool() + with pytest.raises(ToolError, match="confirm=True"): + await tool_fn(action="update_ssh", ssh_enabled=True, ssh_port=22) + async def test_update_ssh_requires_enabled(self, _mock_graphql: AsyncMock) -> None: tool_fn = _make_tool() with pytest.raises(ToolError, match="ssh_enabled"): - await tool_fn(action="update_ssh", ssh_port=22) + await tool_fn(action="update_ssh", confirm=True, ssh_port=22) async def test_update_ssh_requires_port(self, _mock_graphql: AsyncMock) -> None: tool_fn = _make_tool() with pytest.raises(ToolError, match="ssh_port"): - await tool_fn(action="update_ssh", ssh_enabled=True) + await tool_fn(action="update_ssh", confirm=True, ssh_enabled=True) async def test_update_ssh_success(self, _mock_graphql: AsyncMock) -> None: - _mock_graphql.return_value = {"updateSshSettings": {"id": "s:1", "useSsh": True, "portssh": 22}} + _mock_graphql.return_value = { + "updateSshSettings": {"id": "s:1", "useSsh": True, "portssh": 22} + } tool_fn = _make_tool() - result = await tool_fn(action="update_ssh", ssh_enabled=True, ssh_port=22) + result = await tool_fn(action="update_ssh", confirm=True, ssh_enabled=True, ssh_port=22) assert result["success"] is True assert result["data"]["useSsh"] is True async def test_update_ssh_passes_correct_input(self, _mock_graphql: AsyncMock) -> None: - _mock_graphql.return_value = {"updateSshSettings": {"id": "s:1", "useSsh": False, "portssh": 2222}} + _mock_graphql.return_value = { + "updateSshSettings": {"id": "s:1", "useSsh": False, "portssh": 2222} + } tool_fn = _make_tool() - await tool_fn(action="update_ssh", ssh_enabled=False, ssh_port=2222) + await tool_fn(action="update_ssh", confirm=True, ssh_enabled=False, ssh_port=2222) assert _mock_graphql.call_args[0][1] == {"input": {"enabled": False, "port": 2222}} diff --git a/tests/test_subscription_validation.py b/tests/test_subscription_validation.py index 47a24d8..01b39b5 100644 --- a/tests/test_subscription_validation.py +++ b/tests/test_subscription_validation.py @@ -23,21 +23,21 @@ class TestValidateSubscriptionQueryAllowed: assert result == sub_name def test_returns_extracted_subscription_name(self) -> None: - query = "subscription { cpuSubscription { usage } }" - assert _validate_subscription_query(query) == "cpuSubscription" + query = "subscription { cpu { usage } }" + assert _validate_subscription_query(query) == "cpu" def test_leading_whitespace_accepted(self) -> None: - query = " subscription { memorySubscription { free } }" - assert _validate_subscription_query(query) == "memorySubscription" + query = " subscription { memory { free } }" + assert _validate_subscription_query(query) == "memory" def test_multiline_query_accepted(self) -> None: - query = "subscription {\n logFileSubscription {\n content\n }\n}" - assert _validate_subscription_query(query) == "logFileSubscription" + query = "subscription {\n logFile {\n content\n }\n}" + assert _validate_subscription_query(query) == "logFile" def test_case_insensitive_subscription_keyword(self) -> None: """'SUBSCRIPTION' should be accepted (regex uses IGNORECASE).""" - query = "SUBSCRIPTION { cpuSubscription { usage } }" - assert _validate_subscription_query(query) == "cpuSubscription" + query = "SUBSCRIPTION { cpu { usage } }" + assert _validate_subscription_query(query) == "cpu" class TestValidateSubscriptionQueryForbiddenKeywords: @@ -72,16 +72,16 @@ class TestValidateSubscriptionQueryForbiddenKeywords: def test_mutation_field_identifier_not_rejected(self) -> None: """'mutationField' as an identifier must NOT be rejected — only standalone 'mutation'.""" # This tests the \b word boundary in _FORBIDDEN_KEYWORDS - query = "subscription { cpuSubscription { mutationField } }" + query = "subscription { cpu { mutationField } }" # Should not raise — "mutationField" is an identifier, not the keyword result = _validate_subscription_query(query) - assert result == "cpuSubscription" + assert result == "cpu" def test_query_field_identifier_not_rejected(self) -> None: """'queryResult' as an identifier must NOT be rejected.""" - query = "subscription { cpuSubscription { queryResult } }" + query = "subscription { cpu { queryResult } }" result = _validate_subscription_query(query) - assert result == "cpuSubscription" + assert result == "cpu" class TestValidateSubscriptionQueryInvalidFormat: @@ -114,9 +114,9 @@ class TestValidateSubscriptionQueryUnknownName: _validate_subscription_query(query) def test_error_message_includes_allowed_list(self) -> None: - """Error message must list the allowed subscription names for usability.""" + """Error message must list the allowed subscription field names for usability.""" query = "subscription { badSub { data } }" - with pytest.raises(ToolError, match="Allowed subscriptions"): + with pytest.raises(ToolError, match="Allowed fields"): _validate_subscription_query(query) def test_arbitrary_field_name_rejected(self) -> None: @@ -125,7 +125,7 @@ class TestValidateSubscriptionQueryUnknownName: _validate_subscription_query(query) def test_close_but_not_whitelisted_rejected(self) -> None: - """'cpu' without 'Subscription' suffix is not in the allow-list.""" - query = "subscription { cpu { usage } }" + """'cpuSubscription' (old operation-style name) is not in the field allow-list.""" + query = "subscription { cpuSubscription { usage } }" with pytest.raises(ToolError, match="not allowed"): _validate_subscription_query(query)