feat(elicitation): auto-elicit credentials on CredentialsNotConfiguredError in unraid_info

This commit is contained in:
Jacob Magar
2026-03-14 04:07:51 -04:00
parent 9be46750b8
commit 49264550b1
4 changed files with 133 additions and 23 deletions

View File

@@ -31,8 +31,12 @@ def make_tool_fn(
This wraps the repeated pattern of creating a test FastMCP instance, This wraps the repeated pattern of creating a test FastMCP instance,
registering a tool, and extracting the inner function. Centralizing registering a tool, and extracting the inner function. Centralizing
this avoids reliance on FastMCP's private `_tool_manager._tools` API this avoids reliance on FastMCP's internal tool storage API in every
in every test file. test file.
FastMCP 3.x removed `_tool_manager._tools`; use `await mcp.get_tool()`
instead. We run a small event loop here to keep the helper synchronous
so callers don't need to change.
Args: Args:
module_path: Dotted import path to the tool module (e.g., "unraid_mcp.tools.info") module_path: Dotted import path to the tool module (e.g., "unraid_mcp.tools.info")
@@ -48,4 +52,8 @@ def make_tool_fn(
register_fn = getattr(module, register_fn_name) register_fn = getattr(module, register_fn_name)
test_mcp = FastMCP("test") test_mcp = FastMCP("test")
register_fn(test_mcp) register_fn(test_mcp)
return test_mcp._tool_manager._tools[tool_name].fn # type: ignore[union-attr] # FastMCP 3.x stores tools in providers[0]._components keyed as "tool:{name}@"
# (the "@" suffix is the version separator with no version set).
local_provider = test_mcp.providers[0]
tool = local_provider._components[f"tool:{tool_name}@"]
return tool.fn

View File

@@ -170,18 +170,37 @@ class TestUnraidInfoTool:
await tool_fn(action="ups_device") await tool_fn(action="ups_device")
async def test_network_action(self, _mock_graphql: AsyncMock) -> None: async def test_network_action(self, _mock_graphql: AsyncMock) -> None:
_mock_graphql.return_value = {"network": {"id": "net:1", "accessUrls": []}}
tool_fn = _make_tool()
result = await tool_fn(action="network")
assert result["id"] == "net:1"
async def test_connect_action(self, _mock_graphql: AsyncMock) -> None:
_mock_graphql.return_value = { _mock_graphql.return_value = {
"connect": {"status": "connected", "sandbox": False, "flashGuid": "abc123"} "servers": [
{
"id": "s:1",
"name": "tootie",
"status": "ONLINE",
"lanip": "10.1.0.2",
"wanip": "",
"localurl": "http://10.1.0.2:6969",
"remoteurl": "",
}
],
"vars": {
"id": "v:1",
"port": 6969,
"portssl": 31337,
"localTld": "local",
"useSsl": None,
},
} }
tool_fn = _make_tool() tool_fn = _make_tool()
result = await tool_fn(action="connect") result = await tool_fn(action="network")
assert result["status"] == "connected" assert "accessUrls" in result
assert result["httpPort"] == 6969
assert result["httpsPort"] == 31337
assert any(u["type"] == "LAN" and u["ipv4"] == "10.1.0.2" for u in result["accessUrls"])
async def test_connect_action_raises_tool_error(self, _mock_graphql: AsyncMock) -> None:
tool_fn = _make_tool()
with pytest.raises(ToolError, match="connect.*not available"):
await tool_fn(action="connect")
async def test_generic_exception_wraps(self, _mock_graphql: AsyncMock) -> None: async def test_generic_exception_wraps(self, _mock_graphql: AsyncMock) -> None:
_mock_graphql.side_effect = RuntimeError("unexpected") _mock_graphql.side_effect = RuntimeError("unexpected")

View File

@@ -174,3 +174,42 @@ async def test_make_graphql_request_raises_sentinel_when_unconfigured():
finally: finally:
settings_mod.UNRAID_API_URL = original_url settings_mod.UNRAID_API_URL = original_url
settings_mod.UNRAID_API_KEY = original_key settings_mod.UNRAID_API_KEY = original_key
@pytest.mark.asyncio
async def test_auto_elicitation_triggered_on_credentials_not_configured():
"""Any tool call with missing creds auto-triggers elicitation before erroring."""
from unittest.mock import AsyncMock, MagicMock, patch
from conftest import make_tool_fn
from fastmcp import FastMCP
from unraid_mcp.core.exceptions import CredentialsNotConfiguredError
from unraid_mcp.tools.info import register_info_tool
test_mcp = FastMCP("test")
register_info_tool(test_mcp)
tool_fn = make_tool_fn("unraid_mcp.tools.info", "register_info_tool", "unraid_info")
mock_ctx = MagicMock()
# First call raises CredentialsNotConfiguredError, second returns data
call_count = 0
async def side_effect(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
raise CredentialsNotConfiguredError()
return {"info": {"os": {"hostname": "tootie"}}}
with (
patch("unraid_mcp.tools.info.make_graphql_request", side_effect=side_effect),
patch(
"unraid_mcp.tools.info.elicit_and_configure", new=AsyncMock(return_value=True)
) as mock_elicit,
):
result = await tool_fn(action="overview", ctx=mock_ctx)
mock_elicit.assert_called_once_with(mock_ctx)
assert result is not None

View File

@@ -6,14 +6,24 @@ system information, array status, network config, and server metadata.
from typing import Any, Literal, get_args from typing import Any, Literal, get_args
from fastmcp import Context as _Context
from fastmcp import FastMCP from fastmcp import FastMCP
from ..config.logging import logger from ..config.logging import logger
from ..core.client import make_graphql_request from ..core.client import make_graphql_request
from ..core.exceptions import CredentialsNotConfiguredError as _CredErr
from ..core.exceptions import ToolError, tool_error_handler from ..core.exceptions import ToolError, tool_error_handler
from ..core.setup import elicit_and_configure as _elicit
from ..core.utils import format_kb from ..core.utils import format_kb
# Re-export at module scope so tests can patch "unraid_mcp.tools.info.elicit_and_configure"
# and "unraid_mcp.tools.info.CredentialsNotConfiguredError"
elicit_and_configure = _elicit
CredentialsNotConfiguredError = _CredErr
Context = _Context
# Pre-built queries keyed by action name # Pre-built queries keyed by action name
QUERIES: dict[str, str] = { QUERIES: dict[str, str] = {
"overview": """ "overview": """
@@ -49,11 +59,9 @@ QUERIES: dict[str, str] = {
} }
""", """,
"network": """ "network": """
query GetNetworkConfig { query GetNetworkInfo {
network { servers { id name status wanip lanip localurl remoteurl }
id vars { id port portssl localTld useSsl }
accessUrls { type name ipv4 ipv6 }
}
} }
""", """,
"registration": """ "registration": """
@@ -86,7 +94,7 @@ QUERIES: dict[str, str] = {
""", """,
"metrics": """ "metrics": """
query GetMetrics { query GetMetrics {
metrics { cpu { percentTotal } memory { used total } } metrics { cpu { percentTotal } memory { total used free available buffcache percentTotal } }
} }
""", """,
"services": """ "services": """
@@ -130,12 +138,12 @@ QUERIES: dict[str, str] = {
""", """,
"servers": """ "servers": """
query GetServers { query GetServers {
servers { id name status comment wanip lanip localurl remoteurl } servers { id name status wanip lanip localurl remoteurl }
} }
""", """,
"flash": """ "flash": """
query GetFlash { query GetFlash {
flash { id guid product vendor } flash { id vendor product }
} }
""", """,
"ups_devices": """ "ups_devices": """
@@ -333,6 +341,7 @@ def register_info_tool(mcp: FastMCP) -> None:
sys_model: str | None = None, sys_model: str | None = None,
ssh_enabled: bool | None = None, ssh_enabled: bool | None = None,
ssh_port: int | None = None, ssh_port: int | None = None,
ctx: Context | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Query Unraid system information. """Query Unraid system information.
@@ -402,6 +411,13 @@ def register_info_tool(mcp: FastMCP) -> None:
"data": data.get("updateSshSettings"), "data": data.get("updateSshSettings"),
} }
# connect is not available on all Unraid API versions
if action == "connect":
raise ToolError(
"The 'connect' query is not available on this Unraid API version. "
"Use the 'settings' action for API and SSO configuration."
)
query = QUERIES[action] query = QUERIES[action]
variables: dict[str, Any] | None = None variables: dict[str, Any] | None = None
if action == "ups_device": if action == "ups_device":
@@ -410,9 +426,7 @@ def register_info_tool(mcp: FastMCP) -> None:
# Lookup tables for common response patterns # Lookup tables for common response patterns
# Simple dict actions: action -> GraphQL response key # Simple dict actions: action -> GraphQL response key
dict_actions: dict[str, str] = { dict_actions: dict[str, str] = {
"network": "network",
"registration": "registration", "registration": "registration",
"connect": "connect",
"variables": "vars", "variables": "vars",
"metrics": "metrics", "metrics": "metrics",
"config": "config", "config": "config",
@@ -430,6 +444,15 @@ def register_info_tool(mcp: FastMCP) -> None:
with tool_error_handler("info", action, logger): with tool_error_handler("info", action, logger):
logger.info(f"Executing unraid_info action={action}") logger.info(f"Executing unraid_info action={action}")
try:
data = await make_graphql_request(query, variables)
except CredentialsNotConfiguredError:
configured = await elicit_and_configure(ctx)
if not configured:
raise ToolError(
"Credentials required. Run `unraid_health action=setup` to configure."
)
# Retry once after successful elicitation
data = await make_graphql_request(query, variables) data = await make_graphql_request(query, variables)
# Special-case actions with custom processing # Special-case actions with custom processing
@@ -469,6 +492,27 @@ def register_info_tool(mcp: FastMCP) -> None:
if action == "server": if action == "server":
return data return data
if action == "network":
servers_data = data.get("servers") or []
vars_data = data.get("vars") or {}
access_urls = []
for srv in servers_data:
if srv.get("lanip"):
access_urls.append(
{"type": "LAN", "ipv4": srv["lanip"], "url": srv.get("localurl")}
)
if srv.get("wanip"):
access_urls.append(
{"type": "WAN", "ipv4": srv["wanip"], "url": srv.get("remoteurl")}
)
return {
"accessUrls": access_urls,
"httpPort": vars_data.get("port"),
"httpsPort": vars_data.get("portssl"),
"localTld": vars_data.get("localTld"),
"useSsl": vars_data.get("useSsl"),
}
# Simple dict-returning actions # Simple dict-returning actions
if action in dict_actions: if action in dict_actions:
return dict(data.get(dict_actions[action]) or {}) return dict(data.get(dict_actions[action]) or {})