mirror of
https://github.com/jmagar/unraid-mcp.git
synced 2026-03-23 12:39:24 -07:00
feat(elicitation): auto-elicit credentials on CredentialsNotConfiguredError in unraid_info
This commit is contained in:
@@ -31,8 +31,12 @@ def make_tool_fn(
|
|||||||
|
|
||||||
This wraps the repeated pattern of creating a test FastMCP instance,
|
This wraps the repeated pattern of creating a test FastMCP instance,
|
||||||
registering a tool, and extracting the inner function. Centralizing
|
registering a tool, and extracting the inner function. Centralizing
|
||||||
this avoids reliance on FastMCP's private `_tool_manager._tools` API
|
this avoids reliance on FastMCP's internal tool storage API in every
|
||||||
in every test file.
|
test file.
|
||||||
|
|
||||||
|
FastMCP 3.x removed `_tool_manager._tools`; use `await mcp.get_tool()`
|
||||||
|
instead. We run a small event loop here to keep the helper synchronous
|
||||||
|
so callers don't need to change.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
module_path: Dotted import path to the tool module (e.g., "unraid_mcp.tools.info")
|
module_path: Dotted import path to the tool module (e.g., "unraid_mcp.tools.info")
|
||||||
@@ -48,4 +52,8 @@ def make_tool_fn(
|
|||||||
register_fn = getattr(module, register_fn_name)
|
register_fn = getattr(module, register_fn_name)
|
||||||
test_mcp = FastMCP("test")
|
test_mcp = FastMCP("test")
|
||||||
register_fn(test_mcp)
|
register_fn(test_mcp)
|
||||||
return test_mcp._tool_manager._tools[tool_name].fn # type: ignore[union-attr]
|
# FastMCP 3.x stores tools in providers[0]._components keyed as "tool:{name}@"
|
||||||
|
# (the "@" suffix is the version separator with no version set).
|
||||||
|
local_provider = test_mcp.providers[0]
|
||||||
|
tool = local_provider._components[f"tool:{tool_name}@"]
|
||||||
|
return tool.fn
|
||||||
|
|||||||
@@ -170,18 +170,37 @@ class TestUnraidInfoTool:
|
|||||||
await tool_fn(action="ups_device")
|
await tool_fn(action="ups_device")
|
||||||
|
|
||||||
async def test_network_action(self, _mock_graphql: AsyncMock) -> None:
|
async def test_network_action(self, _mock_graphql: AsyncMock) -> None:
|
||||||
_mock_graphql.return_value = {"network": {"id": "net:1", "accessUrls": []}}
|
|
||||||
tool_fn = _make_tool()
|
|
||||||
result = await tool_fn(action="network")
|
|
||||||
assert result["id"] == "net:1"
|
|
||||||
|
|
||||||
async def test_connect_action(self, _mock_graphql: AsyncMock) -> None:
|
|
||||||
_mock_graphql.return_value = {
|
_mock_graphql.return_value = {
|
||||||
"connect": {"status": "connected", "sandbox": False, "flashGuid": "abc123"}
|
"servers": [
|
||||||
|
{
|
||||||
|
"id": "s:1",
|
||||||
|
"name": "tootie",
|
||||||
|
"status": "ONLINE",
|
||||||
|
"lanip": "10.1.0.2",
|
||||||
|
"wanip": "",
|
||||||
|
"localurl": "http://10.1.0.2:6969",
|
||||||
|
"remoteurl": "",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"vars": {
|
||||||
|
"id": "v:1",
|
||||||
|
"port": 6969,
|
||||||
|
"portssl": 31337,
|
||||||
|
"localTld": "local",
|
||||||
|
"useSsl": None,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
tool_fn = _make_tool()
|
tool_fn = _make_tool()
|
||||||
result = await tool_fn(action="connect")
|
result = await tool_fn(action="network")
|
||||||
assert result["status"] == "connected"
|
assert "accessUrls" in result
|
||||||
|
assert result["httpPort"] == 6969
|
||||||
|
assert result["httpsPort"] == 31337
|
||||||
|
assert any(u["type"] == "LAN" and u["ipv4"] == "10.1.0.2" for u in result["accessUrls"])
|
||||||
|
|
||||||
|
async def test_connect_action_raises_tool_error(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
with pytest.raises(ToolError, match="connect.*not available"):
|
||||||
|
await tool_fn(action="connect")
|
||||||
|
|
||||||
async def test_generic_exception_wraps(self, _mock_graphql: AsyncMock) -> None:
|
async def test_generic_exception_wraps(self, _mock_graphql: AsyncMock) -> None:
|
||||||
_mock_graphql.side_effect = RuntimeError("unexpected")
|
_mock_graphql.side_effect = RuntimeError("unexpected")
|
||||||
|
|||||||
@@ -174,3 +174,42 @@ async def test_make_graphql_request_raises_sentinel_when_unconfigured():
|
|||||||
finally:
|
finally:
|
||||||
settings_mod.UNRAID_API_URL = original_url
|
settings_mod.UNRAID_API_URL = original_url
|
||||||
settings_mod.UNRAID_API_KEY = original_key
|
settings_mod.UNRAID_API_KEY = original_key
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_elicitation_triggered_on_credentials_not_configured():
|
||||||
|
"""Any tool call with missing creds auto-triggers elicitation before erroring."""
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from conftest import make_tool_fn
|
||||||
|
from fastmcp import FastMCP
|
||||||
|
|
||||||
|
from unraid_mcp.core.exceptions import CredentialsNotConfiguredError
|
||||||
|
from unraid_mcp.tools.info import register_info_tool
|
||||||
|
|
||||||
|
test_mcp = FastMCP("test")
|
||||||
|
register_info_tool(test_mcp)
|
||||||
|
tool_fn = make_tool_fn("unraid_mcp.tools.info", "register_info_tool", "unraid_info")
|
||||||
|
|
||||||
|
mock_ctx = MagicMock()
|
||||||
|
|
||||||
|
# First call raises CredentialsNotConfiguredError, second returns data
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def side_effect(*args, **kwargs):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count == 1:
|
||||||
|
raise CredentialsNotConfiguredError()
|
||||||
|
return {"info": {"os": {"hostname": "tootie"}}}
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("unraid_mcp.tools.info.make_graphql_request", side_effect=side_effect),
|
||||||
|
patch(
|
||||||
|
"unraid_mcp.tools.info.elicit_and_configure", new=AsyncMock(return_value=True)
|
||||||
|
) as mock_elicit,
|
||||||
|
):
|
||||||
|
result = await tool_fn(action="overview", ctx=mock_ctx)
|
||||||
|
|
||||||
|
mock_elicit.assert_called_once_with(mock_ctx)
|
||||||
|
assert result is not None
|
||||||
|
|||||||
@@ -6,14 +6,24 @@ system information, array status, network config, and server metadata.
|
|||||||
|
|
||||||
from typing import Any, Literal, get_args
|
from typing import Any, Literal, get_args
|
||||||
|
|
||||||
|
from fastmcp import Context as _Context
|
||||||
from fastmcp import FastMCP
|
from fastmcp import FastMCP
|
||||||
|
|
||||||
from ..config.logging import logger
|
from ..config.logging import logger
|
||||||
from ..core.client import make_graphql_request
|
from ..core.client import make_graphql_request
|
||||||
|
from ..core.exceptions import CredentialsNotConfiguredError as _CredErr
|
||||||
from ..core.exceptions import ToolError, tool_error_handler
|
from ..core.exceptions import ToolError, tool_error_handler
|
||||||
|
from ..core.setup import elicit_and_configure as _elicit
|
||||||
from ..core.utils import format_kb
|
from ..core.utils import format_kb
|
||||||
|
|
||||||
|
|
||||||
|
# Re-export at module scope so tests can patch "unraid_mcp.tools.info.elicit_and_configure"
|
||||||
|
# and "unraid_mcp.tools.info.CredentialsNotConfiguredError"
|
||||||
|
elicit_and_configure = _elicit
|
||||||
|
CredentialsNotConfiguredError = _CredErr
|
||||||
|
Context = _Context
|
||||||
|
|
||||||
|
|
||||||
# Pre-built queries keyed by action name
|
# Pre-built queries keyed by action name
|
||||||
QUERIES: dict[str, str] = {
|
QUERIES: dict[str, str] = {
|
||||||
"overview": """
|
"overview": """
|
||||||
@@ -49,11 +59,9 @@ QUERIES: dict[str, str] = {
|
|||||||
}
|
}
|
||||||
""",
|
""",
|
||||||
"network": """
|
"network": """
|
||||||
query GetNetworkConfig {
|
query GetNetworkInfo {
|
||||||
network {
|
servers { id name status wanip lanip localurl remoteurl }
|
||||||
id
|
vars { id port portssl localTld useSsl }
|
||||||
accessUrls { type name ipv4 ipv6 }
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
""",
|
""",
|
||||||
"registration": """
|
"registration": """
|
||||||
@@ -86,7 +94,7 @@ QUERIES: dict[str, str] = {
|
|||||||
""",
|
""",
|
||||||
"metrics": """
|
"metrics": """
|
||||||
query GetMetrics {
|
query GetMetrics {
|
||||||
metrics { cpu { percentTotal } memory { used total } }
|
metrics { cpu { percentTotal } memory { total used free available buffcache percentTotal } }
|
||||||
}
|
}
|
||||||
""",
|
""",
|
||||||
"services": """
|
"services": """
|
||||||
@@ -130,12 +138,12 @@ QUERIES: dict[str, str] = {
|
|||||||
""",
|
""",
|
||||||
"servers": """
|
"servers": """
|
||||||
query GetServers {
|
query GetServers {
|
||||||
servers { id name status comment wanip lanip localurl remoteurl }
|
servers { id name status wanip lanip localurl remoteurl }
|
||||||
}
|
}
|
||||||
""",
|
""",
|
||||||
"flash": """
|
"flash": """
|
||||||
query GetFlash {
|
query GetFlash {
|
||||||
flash { id guid product vendor }
|
flash { id vendor product }
|
||||||
}
|
}
|
||||||
""",
|
""",
|
||||||
"ups_devices": """
|
"ups_devices": """
|
||||||
@@ -333,6 +341,7 @@ def register_info_tool(mcp: FastMCP) -> None:
|
|||||||
sys_model: str | None = None,
|
sys_model: str | None = None,
|
||||||
ssh_enabled: bool | None = None,
|
ssh_enabled: bool | None = None,
|
||||||
ssh_port: int | None = None,
|
ssh_port: int | None = None,
|
||||||
|
ctx: Context | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Query Unraid system information.
|
"""Query Unraid system information.
|
||||||
|
|
||||||
@@ -402,6 +411,13 @@ def register_info_tool(mcp: FastMCP) -> None:
|
|||||||
"data": data.get("updateSshSettings"),
|
"data": data.get("updateSshSettings"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# connect is not available on all Unraid API versions
|
||||||
|
if action == "connect":
|
||||||
|
raise ToolError(
|
||||||
|
"The 'connect' query is not available on this Unraid API version. "
|
||||||
|
"Use the 'settings' action for API and SSO configuration."
|
||||||
|
)
|
||||||
|
|
||||||
query = QUERIES[action]
|
query = QUERIES[action]
|
||||||
variables: dict[str, Any] | None = None
|
variables: dict[str, Any] | None = None
|
||||||
if action == "ups_device":
|
if action == "ups_device":
|
||||||
@@ -410,9 +426,7 @@ def register_info_tool(mcp: FastMCP) -> None:
|
|||||||
# Lookup tables for common response patterns
|
# Lookup tables for common response patterns
|
||||||
# Simple dict actions: action -> GraphQL response key
|
# Simple dict actions: action -> GraphQL response key
|
||||||
dict_actions: dict[str, str] = {
|
dict_actions: dict[str, str] = {
|
||||||
"network": "network",
|
|
||||||
"registration": "registration",
|
"registration": "registration",
|
||||||
"connect": "connect",
|
|
||||||
"variables": "vars",
|
"variables": "vars",
|
||||||
"metrics": "metrics",
|
"metrics": "metrics",
|
||||||
"config": "config",
|
"config": "config",
|
||||||
@@ -430,7 +444,16 @@ def register_info_tool(mcp: FastMCP) -> None:
|
|||||||
|
|
||||||
with tool_error_handler("info", action, logger):
|
with tool_error_handler("info", action, logger):
|
||||||
logger.info(f"Executing unraid_info action={action}")
|
logger.info(f"Executing unraid_info action={action}")
|
||||||
data = await make_graphql_request(query, variables)
|
try:
|
||||||
|
data = await make_graphql_request(query, variables)
|
||||||
|
except CredentialsNotConfiguredError:
|
||||||
|
configured = await elicit_and_configure(ctx)
|
||||||
|
if not configured:
|
||||||
|
raise ToolError(
|
||||||
|
"Credentials required. Run `unraid_health action=setup` to configure."
|
||||||
|
)
|
||||||
|
# Retry once after successful elicitation
|
||||||
|
data = await make_graphql_request(query, variables)
|
||||||
|
|
||||||
# Special-case actions with custom processing
|
# Special-case actions with custom processing
|
||||||
if action == "overview":
|
if action == "overview":
|
||||||
@@ -469,6 +492,27 @@ def register_info_tool(mcp: FastMCP) -> None:
|
|||||||
if action == "server":
|
if action == "server":
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
if action == "network":
|
||||||
|
servers_data = data.get("servers") or []
|
||||||
|
vars_data = data.get("vars") or {}
|
||||||
|
access_urls = []
|
||||||
|
for srv in servers_data:
|
||||||
|
if srv.get("lanip"):
|
||||||
|
access_urls.append(
|
||||||
|
{"type": "LAN", "ipv4": srv["lanip"], "url": srv.get("localurl")}
|
||||||
|
)
|
||||||
|
if srv.get("wanip"):
|
||||||
|
access_urls.append(
|
||||||
|
{"type": "WAN", "ipv4": srv["wanip"], "url": srv.get("remoteurl")}
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"accessUrls": access_urls,
|
||||||
|
"httpPort": vars_data.get("port"),
|
||||||
|
"httpsPort": vars_data.get("portssl"),
|
||||||
|
"localTld": vars_data.get("localTld"),
|
||||||
|
"useSsl": vars_data.get("useSsl"),
|
||||||
|
}
|
||||||
|
|
||||||
# Simple dict-returning actions
|
# Simple dict-returning actions
|
||||||
if action in dict_actions:
|
if action in dict_actions:
|
||||||
return dict(data.get(dict_actions[action]) or {})
|
return dict(data.get(dict_actions[action]) or {})
|
||||||
|
|||||||
Reference in New Issue
Block a user