feat: add API key bearer token authentication

- ApiKeyVerifier(TokenVerifier) — validates Authorization: Bearer <key>
  against UNRAID_MCP_API_KEY; guards against empty-key bypass
- _build_auth() replaces module-level _build_google_auth() call:
  returns MultiAuth(server=google, verifiers=[api_key]) when both set,
  GoogleProvider alone, ApiKeyVerifier alone, or None
- settings.py: add UNRAID_MCP_API_KEY + is_api_key_auth_configured()
  + api_key_auth_enabled in get_config_summary()
- run_server(): improved auth status logging for all three states
- tests/test_api_key_auth.py: 9 tests covering verifier + _build_auth
- .env.example: add UNRAID_MCP_API_KEY section
- docs/GOOGLE_OAUTH.md: add API Key section
- README.md / CLAUDE.md: rename section, document both auth methods
- Fix pre-existing: test_health.py patched cache_middleware/error_middleware
  now match renamed _cache_middleware/_error_middleware in server.py
This commit is contained in:
Jacob Magar
2026-03-16 11:11:38 -04:00
parent 6f7a58a0f9
commit cc24f1ec62
16 changed files with 406 additions and 69 deletions

View File

@@ -98,6 +98,19 @@ def is_google_auth_configured() -> bool:
return bool(GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET and UNRAID_MCP_BASE_URL)
# API Key Authentication (Optional)
# ----------------------------------
# A static bearer token clients can use instead of (or alongside) Google OAuth.
# Can be set to the same value as UNRAID_API_KEY for simplicity, or a separate
# dedicated secret for MCP access.
UNRAID_MCP_API_KEY = os.getenv("UNRAID_MCP_API_KEY", "")
def is_api_key_auth_configured() -> bool:
"""Return True when UNRAID_MCP_API_KEY is set."""
return bool(UNRAID_MCP_API_KEY)
# Logging Configuration
LOG_LEVEL_STR = os.getenv("UNRAID_MCP_LOG_LEVEL", "INFO").upper()
LOG_FILE_NAME = os.getenv("UNRAID_MCP_LOG_FILE", "unraid-mcp.log")
@@ -180,6 +193,7 @@ def get_config_summary() -> dict[str, Any]:
"google_auth_enabled": is_google_auth_configured(),
"google_auth_base_url": UNRAID_MCP_BASE_URL if is_google_auth_configured() else None,
"jwt_signing_key_configured": bool(UNRAID_MCP_JWT_SIGNING_KEY),
"api_key_auth_enabled": is_api_key_auth_configured(),
}

View File

@@ -8,6 +8,7 @@ import sys
from typing import Any
from fastmcp import FastMCP
from fastmcp.server.auth import AccessToken, MultiAuth, TokenVerifier
from fastmcp.server.auth.providers.google import GoogleProvider
from fastmcp.server.middleware.caching import CallToolSettings, ResponseCachingMiddleware
from fastmcp.server.middleware.error_handling import ErrorHandlingMiddleware
@@ -41,26 +42,32 @@ _logging_middleware = LoggingMiddleware(
# 2. Catch any unhandled exceptions and convert to proper MCP errors.
# Tracks error_counts per (exception_type:method) for health diagnose.
error_middleware = ErrorHandlingMiddleware(
_error_middleware = ErrorHandlingMiddleware(
logger=logger,
include_traceback=True,
)
# 3. Unraid API rate limit: 100 requests per 10 seconds.
# Use a sliding window that stays comfortably under that cap.
_rate_limiter = SlidingWindowRateLimitingMiddleware(max_requests=90, window_minutes=1)
# SlidingWindowRateLimitingMiddleware only accepts window_minutes (int), so express
# the 10-second budget as a 1-minute equivalent: 540 req/60 s to stay comfortably
# under the 600 req/min ceiling.
_rate_limiter = SlidingWindowRateLimitingMiddleware(max_requests=540, window_minutes=1)
# 4. Cap tool responses at 512 KB to protect the client context window.
# Oversized responses are truncated with a clear suffix rather than erroring.
_response_limiter = ResponseLimitingMiddleware(max_size=512_000)
# 5. Cache tool calls in-memory (MemoryStore default — no extra deps).
# Short 30 s TTL absorbs burst duplicate requests while keeping data fresh.
# Destructive calls won't hit the cache in practice (unique confirm=True + IDs).
cache_middleware = ResponseCachingMiddleware(
# 5. Cache middleware — all call_tool caching is disabled for the `unraid` tool.
# CallToolSettings supports excluded_tools/included_tools by tool name only; there
# is no per-argument or per-subaction exclusion mechanism. The cache key is
# "{tool_name}:{arguments_str}", so a cached stop("nginx") result would be served
# back on a retry within the TTL window even though the container is already stopped.
# Mutation subactions (start, stop, restart, reboot, etc.) must never be cached.
# Because the consolidated `unraid` tool mixes reads and mutations under one name,
# the only safe option is to disable caching for the entire tool.
_cache_middleware = ResponseCachingMiddleware(
call_tool_settings=CallToolSettings(
ttl=30,
included_tools=["unraid"],
enabled=False,
),
# Disable caching for list/resource/prompt — those are cheap.
list_tools_settings={"enabled": False},
@@ -71,6 +78,30 @@ cache_middleware = ResponseCachingMiddleware(
)
class ApiKeyVerifier(TokenVerifier):
"""Bearer token verifier that validates against a static API key.
Clients present the key as a standard OAuth bearer token:
Authorization: Bearer <UNRAID_MCP_API_KEY>
This allows machine-to-machine access (e.g. CI, scripts, other agents)
without going through the Google OAuth browser flow.
"""
def __init__(self, api_key: str) -> None:
super().__init__()
self._api_key = api_key
async def verify_token(self, token: str) -> AccessToken | None:
if self._api_key and token == self._api_key:
return AccessToken(
token=token,
client_id="api-key-client",
scopes=[],
)
return None
def _build_google_auth() -> "GoogleProvider | None":
"""Build GoogleProvider when OAuth env vars are configured, else return None.
@@ -117,21 +148,45 @@ def _build_google_auth() -> "GoogleProvider | None":
return GoogleProvider(**kwargs)
# Build auth provider — returns GoogleProvider when configured, None otherwise.
_google_auth = _build_google_auth()
def _build_auth() -> "GoogleProvider | ApiKeyVerifier | MultiAuth | None":
"""Build the active auth stack from environment configuration.
Returns:
- MultiAuth(server=GoogleProvider, verifiers=[ApiKeyVerifier])
when both GOOGLE_CLIENT_ID and UNRAID_MCP_API_KEY are set.
- GoogleProvider alone when only Google OAuth vars are set.
- ApiKeyVerifier alone when only UNRAID_MCP_API_KEY is set.
- None when no auth vars are configured (open server).
"""
from .config.settings import UNRAID_MCP_API_KEY, is_api_key_auth_configured
google = _build_google_auth()
api_key = ApiKeyVerifier(UNRAID_MCP_API_KEY) if is_api_key_auth_configured() else None
if google and api_key:
logger.info("Auth: Google OAuth + API key both enabled (MultiAuth)")
return MultiAuth(server=google, verifiers=[api_key])
if api_key:
logger.info("Auth: API key authentication enabled")
return api_key
return google # GoogleProvider or None
# Build auth stack — GoogleProvider, ApiKeyVerifier, MultiAuth, or None.
_auth = _build_auth()
# Initialize FastMCP instance
mcp = FastMCP(
name="Unraid MCP Server",
instructions="Provides tools to interact with an Unraid server's GraphQL API.",
version=VERSION,
auth=_google_auth,
auth=_auth,
middleware=[
_logging_middleware,
error_middleware,
_error_middleware,
_rate_limiter,
_response_limiter,
cache_middleware,
_cache_middleware,
],
)
@@ -185,17 +240,25 @@ def run_server() -> None:
"Only use this in trusted networks or for development."
)
if _google_auth is not None:
from .config.settings import UNRAID_MCP_BASE_URL
if _auth is not None:
from .config.settings import is_google_auth_configured
logger.info(
"Google OAuth ENABLED — clients must authenticate before calling tools. "
f"Redirect URI: {UNRAID_MCP_BASE_URL}/auth/callback"
)
if is_google_auth_configured():
from .config.settings import UNRAID_MCP_BASE_URL
logger.info(
"Google OAuth ENABLED — clients must authenticate before calling tools. "
f"Redirect URI: {UNRAID_MCP_BASE_URL}/auth/callback"
)
else:
logger.info(
"API key authentication ENABLED — present UNRAID_MCP_API_KEY as bearer token."
)
else:
logger.warning(
"No authentication configured — MCP server is open to all clients on the network. "
"Set GOOGLE_CLIENT_ID + GOOGLE_CLIENT_SECRET + UNRAID_MCP_BASE_URL to enable OAuth."
"Set GOOGLE_CLIENT_ID + GOOGLE_CLIENT_SECRET + UNRAID_MCP_BASE_URL to enable Google OAuth, "
"or set UNRAID_MCP_API_KEY to enable bearer token authentication."
)
logger.info(

View File

@@ -285,6 +285,16 @@ async def _handle_system(subaction: str, device_id: str | None) -> dict[str, Any
# ===========================================================================
_HEALTH_SUBACTIONS: set[str] = {"check", "test_connection", "diagnose", "setup"}
_HEALTH_QUERIES: dict[str, str] = {
"comprehensive_health": (
"query ComprehensiveHealthCheck {"
" info { machineId time versions { core { unraid } } os { uptime } }"
" array { state }"
" notifications { overview { unread { alert warning total } } }"
" docker { containers(skipCache: true) { id state status } }"
" }"
),
}
_SEVERITY = {"healthy": 0, "warning": 1, "degraded": 2, "unhealthy": 3}
@@ -346,7 +356,8 @@ async def _handle_health(subaction: str, ctx: Context | None) -> dict[str, Any]
return await _comprehensive_health_check()
if subaction == "diagnose":
from ..server import cache_middleware, error_middleware
from ..server import _cache_middleware as cache_middleware
from ..server import _error_middleware as error_middleware
from ..subscriptions.manager import subscription_manager
from ..subscriptions.resources import ensure_subscriptions_started
@@ -373,7 +384,7 @@ async def _handle_health(subaction: str, ctx: Context | None) -> dict[str, Any]
"call_tool": {
"hits": cache_stats.call_tool.get.hit,
"misses": cache_stats.call_tool.get.miss,
"puts": cache_stats.call_tool.put.total,
"puts": cache_stats.call_tool.put.count,
}
if cache_stats.call_tool
else {"hits": 0, "misses": 0, "puts": 0},
@@ -403,15 +414,7 @@ async def _comprehensive_health_check() -> dict[str, Any]:
health_severity = max(health_severity, _SEVERITY.get(level, 0))
try:
query = """
query ComprehensiveHealthCheck {
info { machineId time versions { core { unraid } } os { uptime } }
array { state }
notifications { overview { unread { alert warning total } } }
docker { containers(skipCache: true) { id state status } }
}
"""
data = await make_graphql_request(query)
data = await make_graphql_request(_HEALTH_QUERIES["comprehensive_health"])
api_latency = round((time.time() - start_time) * 1000, 2)
health_info: dict[str, Any] = {
@@ -738,9 +741,13 @@ _DOCKER_QUERIES: dict[str, str] = {
"details": "query GetContainerDetails { docker { containers(skipCache: false) { id names image imageId command created ports { ip privatePort publicPort type } sizeRootFs labels state status hostConfig { networkMode } networkSettings mounts autoStart } } }",
"networks": "query GetDockerNetworks { docker { networks { id name driver scope } } }",
"network_details": "query GetDockerNetwork { docker { networks { id name driver scope enableIPv6 internal attachable containers options labels } } }",
"_resolve": "query ResolveContainerID { docker { containers(skipCache: true) { id names } } }",
}
# Internal query used only for container ID resolution — not a public subaction.
_DOCKER_RESOLVE_QUERY = (
"query ResolveContainerID { docker { containers(skipCache: true) { id names } } }"
)
_DOCKER_MUTATIONS: dict[str, str] = {
"start": "mutation StartContainer($id: PrefixedID!) { docker { start(id: $id) { id names state status } } }",
"stop": "mutation StopContainer($id: PrefixedID!) { docker { stop(id: $id) { id names state status } } }",
@@ -775,7 +782,7 @@ def _find_container(
async def _resolve_container_id(container_id: str, *, strict: bool = False) -> str:
if _DOCKER_ID_PATTERN.match(container_id):
return container_id
data = await make_graphql_request(_DOCKER_QUERIES["_resolve"])
data = await make_graphql_request(_DOCKER_RESOLVE_QUERY)
containers = safe_get(data, "docker", "containers", default=[])
if _DOCKER_SHORT_ID_PATTERN.match(container_id):
id_lower = container_id.lower()
@@ -1640,7 +1647,7 @@ async def _handle_live(
if subaction == "log_tail":
if not path:
raise ToolError("path is required for live/log_tail")
normalized = os.path.realpath(path) # noqa: ASYNC240
normalized = await asyncio.to_thread(os.path.realpath, path)
if not any(normalized.startswith(p) for p in _LIVE_ALLOWED_LOG_PREFIXES):
raise ToolError(f"path must start with one of: {', '.join(_LIVE_ALLOWED_LOG_PREFIXES)}")
path = normalized