4 Commits

Author SHA1 Message Date
3acdfa3b2b release: version 0.0.2 🚀
All checks were successful
Build Docker image / Create Release (push) Successful in 9s
Build Docker image / deploy (push) Successful in 1m1s
2026-02-21 14:34:25 +01:00
da669a49cb fix: set scripts as executable, refs NOISSUE 2026-02-21 14:34:03 +01:00
d7f7ac7a6f fix: initial test release 2026-02-21 14:02:18 +01:00
e023d89308 ci: add Homelab CI framework 2026-02-21 13:51:09 +01:00
51 changed files with 1471 additions and 2090 deletions

View File

@@ -0,0 +1,50 @@
#!/usr/bin/env sh
# commit-msg hook: reject commits whose message does not follow the
# conventional-commit format (see https://www.conventionalcommits.org).
echo "Running commit message checks..."

# Load the colorPrint/underline helpers relative to this hook's location.
. "$(dirname -- "$0")/../../.gitea/conventional_commits/hooks/text-styles.sh"

# Git invokes commit-msg hooks with the path of the message file as $1;
# fall back to the default location when the script is run by hand.
# (Reading .git/COMMIT_EDITMSG directly breaks with worktrees/--git-dir.)
commit="$(cat "${1:-.git/COMMIT_EDITMSG}")"

# Conventional-commit pattern: type(scope)?!?: description, optionally
# followed by a "ref(s) ISSUE-123" / "NOISSUE" trailer. Release commits
# ("release: ...") are always accepted.
regex='^((build|chore|ci|docs|feat|fix|perf|refactor|revert|style|test)(\(.+\))?(!?):\s([a-zA-Z0-9-_!\&\.\%\(\)\=\w\s]+)(\s?(,?\s?)((ref(s?):?\s?)(([A-Z0-9]+\-[0-9]+)|(NOISSUE)))?))|(release: .*)$'

# Check the message against the pattern. NOTE: -P (PCRE) requires GNU grep.
if ! echo "$commit" | grep -Pq "$regex"
then
    echo
    colorPrint red "❌ Failed to create commit. Your commit message does not follow the conventional commit format."
    colorPrint red "Please use the following format: $(colorPrint brightRed 'type(scope)?: description')"
    colorPrint red "Available types are listed below. Scope is optional. Use ! after type to indicate breaking change."
    echo
    colorPrint brightWhite "Quick examples:"
    echo "feat: add email notifications on new direct messages refs ABC-1213"
    echo "feat(shopping cart): add the amazing button ref: DEFG-23"
    echo "feat!: remove ticket list endpoint ref DADA-109"
    echo "fix(api): handle empty message in request body refs: MINE-82"
    echo "chore(deps): bump some-package-name to version 2.0.0 refs ASDF-12"
    echo
    colorPrint brightWhite "Commit types:"
    colorPrint brightCyan "build: $(colorPrint white "Changes that affect the build system or external dependencies (example scopes: gulp, broccoli, npm)" -n)"
    colorPrint brightCyan "ci: $(colorPrint white "Changes to CI configuration files and scripts (example scopes: Travis, Circle, BrowserStack, SauceLabs)" -n)"
    colorPrint brightCyan "chore: $(colorPrint white "Changes which doesn't change source code or tests e.g. changes to the build process, auxiliary tools, libraries" -n)"
    colorPrint brightCyan "docs: $(colorPrint white "Documentation only changes" -n)"
    colorPrint brightCyan "feat: $(colorPrint white "A new feature" -n)"
    colorPrint brightCyan "fix: $(colorPrint white "A bug fix" -n)"
    colorPrint brightCyan "perf: $(colorPrint white "A code change that improves performance" -n)"
    colorPrint brightCyan "refactor: $(colorPrint white "A code change that neither fixes a bug nor adds a feature" -n)"
    colorPrint brightCyan "revert: $(colorPrint white "Revert a change previously introduced" -n)"
    colorPrint brightCyan "test: $(colorPrint white "Adding missing tests or correcting existing tests" -n)"
    echo
    colorPrint brightWhite "Reminders"
    echo "Put newline before extended commit body"
    echo "More details at $(underline "http://www.conventionalcommits.org")"
    echo
    echo "The commit message you attempted was: $commit"
    echo
    echo "The exact RegEx applied to this message was:"
    colorPrint brightCyan "$regex"
    echo
    exit 1
fi

View File

@@ -0,0 +1,106 @@
#!/bin/bash
# Rules for generating semantic versioning from conventional commits:
# major: breaking change (feat!/style!)
# minor: feat, style
# patch: build, fix, perf, refactor, revert
# Any first argument keeps the temporary messages file around after the
# run (useful for debugging).
PREVENT_REMOVE_FILE=$1
# Scratch directory for the dumped commit subjects.
TEMP_FILE_PATH=.gitea/conventional_commits/tmp
# Most recent tag; --always falls back to an abbreviated commit hash
# when the repository has no tags yet.
LAST_TAG=$(git describe --tags --abbrev=0 --always)
echo "Last tag: #$LAST_TAG#"
# Shape of a release tag (major.minor.patch) — used to decide whether
# LAST_TAG is a real release tag or just a commit hash.
PATTERN="^[0-9]+\.[0-9]+\.[0-9]+$"
# Bump a semantic version string.
#   $1 - current version, "major.minor.patch"
#   $2 - increment type: "major", "minor" or "patch"
# Prints the incremented version on stdout; any other increment type
# echoes the version unchanged.
increment_version() {
    local version=$1
    local increment=$2
    local major minor patch
    # Split "X.Y.Z" on dots in one builtin read instead of three `cut`
    # subprocesses.
    IFS=. read -r major minor patch <<< "$version"
    if [ "$increment" == "major" ]; then
        major=$((major + 1))
        minor=0
        patch=0
    elif [ "$increment" == "minor" ]; then
        minor=$((minor + 1))
        patch=0
    elif [ "$increment" == "patch" ]; then
        patch=$((patch + 1))
    fi
    echo "${major}.${minor}.${patch}"
}
# Dump commit subject lines into $TEMP_FILE_PATH/messages.txt.
#   $1 - "true" to restrict the log to commits since $LAST_TAG,
#        otherwise the whole history is dumped.
# Leaves an already-existing non-empty messages.txt untouched
# (returns 1 in that case).
create_file() {
    local with_range=$1
    if [ -s "$TEMP_FILE_PATH/messages.txt" ]; then
        return 1
    fi
    if [ "$with_range" == "true" ]; then
        git log "$LAST_TAG..HEAD" --no-decorate --pretty=format:"%s" > "$TEMP_FILE_PATH/messages.txt"
    else
        git log --no-decorate --pretty=format:"%s" > "$TEMP_FILE_PATH/messages.txt"
    fi
}
# Refresh messages.txt with the commit subjects that should drive the
# version bump. When LAST_TAG is not a real semver release tag the whole
# history is scanned and the version baseline resets to 0.0.0.
get_commit_range() {
    rm -f "$TEMP_FILE_PATH/messages.txt"
    if [[ $LAST_TAG =~ $PATTERN ]]; then
        create_file true
    else
        create_file
        LAST_TAG="0.0.0"
    fi
    # Trailing padding line so the final subject is still consumed by
    # `while read` even when git omits the terminating newline.
    echo " " >> "$TEMP_FILE_PATH/messages.txt"
}
# Orchestrate the release: collect commit subjects since the last tag,
# classify them against conventional-commit patterns to pick the
# increment type, then regenerate HISTORY.md/VERSION, commit, tag and
# push. Precedence is major > minor > patch.
start() {
mkdir -p $TEMP_FILE_PATH
get_commit_range
new_version=$LAST_TAG
increment_type=""
while read message; do
echo $message
# Breaking change (feat!/style!) => major bump; nothing can outrank it,
# so stop scanning.
if echo $message | grep -Pq '(feat|style)(\([\w]+\))?!:([a-zA-Z0-9-_!\&\.\%\(\)\=\w\s]+)(\s?(,?\s?)((ref(s?):?\s?)(([A-Z0-9]+\-[0-9]+)|(#[0-9]+)|(NOISSUE)))?)'; then
increment_type="major"
echo "a"
break
# feat/style => minor bump; upgrades a previously-seen patch.
elif echo $message | grep -Pq '(feat|style)(\([\w]+\))?:([a-zA-Z0-9-_!\&\.\%\(\)\=\w\s]+)(\s?(,?\s?)((ref(s?):?\s?)(([A-Z0-9]+\-[0-9]+)|(#[0-9]+)|(NOISSUE)))?)'; then
if [ -z "$increment_type" ] || [ "$increment_type" == "patch" ]; then
increment_type="minor"
echo "b"
fi
# build/fix/perf/refactor/revert => patch bump; only set when nothing
# stronger has been seen yet.
elif echo $message | grep -Pq '(build|fix|perf|refactor|revert)(\(.+\))?:\s([a-zA-Z0-9-_!\&\.\%\(\)\=\w\s]+)(\s?(,?\s?)((ref(s?):?\s?)(([A-Z0-9]+\-[0-9]+)|(#[0-9]+)|(NOISSUE)))?)'; then
if [ -z "$increment_type" ]; then
increment_type="patch"
echo "c"
fi
fi
done < $TEMP_FILE_PATH/messages.txt
if [ -n "$increment_type" ]; then
new_version=$(increment_version $LAST_TAG $increment_type)
echo "New version: $new_version"
# Regenerate the changelog, dropping release commits themselves.
gitchangelog | grep -v "[rR]elease:" > HISTORY.md
git add HISTORY.md
echo $new_version > VERSION
git add VERSION
git commit -m "release: version $new_version 🚀"
echo "creating git tag : $new_version"
git tag $new_version
# Push the release commit and the new tag; CI reacts to the tag push.
git push -u origin HEAD --tags
echo "Gitea Actions will detect the new tag and release the new version."
else
echo "No changes requiring a version increment."
fi
}
start
# Clean up the scratch file unless the caller opted to keep it via the
# first positional argument (PREVENT_REMOVE_FILE).
if [ -z "$PREVENT_REMOVE_FILE" ]; then
rm -f $TEMP_FILE_PATH/messages.txt
fi

View File

@@ -0,0 +1,44 @@
#!/bin/sh
# Print text in the named ANSI color.
# Usage: colorPrint <color> <text> [-t] [-n]
#   -t  prefix the text with a tab
#   -n  suppress the trailing newline
# Unknown color names print "Invalid color" and return.
colorPrint() {
    local color=$1
    local text=$2
    shift 2
    local newline="\n"
    local tab=""
    local arg color_code
    for arg in "$@"
    do
        if [ "$arg" = "-t" ]; then
            tab="\t"
        elif [ "$arg" = "-n" ]; then
            newline=""
        fi
    done
    case $color in
        black) color_code="30" ;;
        red) color_code="31" ;;
        green) color_code="32" ;;
        yellow) color_code="33" ;;
        blue) color_code="34" ;;
        magenta) color_code="35" ;;
        cyan) color_code="36" ;;
        white) color_code="37" ;;
        brightBlack) color_code="90" ;;
        brightRed) color_code="91" ;;
        brightGreen) color_code="92" ;;
        brightYellow) color_code="93" ;;
        brightBlue) color_code="94" ;;
        brightMagenta) color_code="95" ;;
        brightCyan) color_code="96" ;;
        brightWhite) color_code="97" ;;
        *) echo "Invalid color"; return ;;
    esac
    # Use \033 instead of \e: \e is a GNU/bash extension and is not
    # guaranteed by POSIX printf under #!/bin/sh; \033 is portable.
    printf "\033[${color_code}m${tab}%s\033[0m${newline}" "$text"
}
# Wrap the given string in ANSI underline-on / underline-off codes.
# No trailing newline is emitted.
underline() {
    printf '\033[4m%s\033[24m' "$1"
}

4
.gitea/release_message.sh Executable file
View File

@@ -0,0 +1,4 @@
#!/usr/bin/env bash
# Generates the changelog body for a release: a `git shortlog` of every
# commit since the previous (second-newest by creation date) tag,
# indented for use as a Markdown release body.
previous_tag=$(git tag --sort=-creatordate | sed -n 2p)
if [ -n "$previous_tag" ]; then
    git shortlog "${previous_tag}.." | sed 's/^./ &/'
else
    # First release: no previous tag exists, so list the whole history
    # instead of passing an invalid ".." range to git shortlog.
    git shortlog HEAD | sed 's/^./ &/'
fi

View File

@@ -0,0 +1,61 @@
name: Build Docker image

permissions:
  contents: write

env:
  SKIP_MAKE_SETUP_CHECK: 'true'

on:
  push:
    # Sequence of patterns matched against refs/tags; '*' triggers on
    # every pushed tag (e.g. 0.0.2 — VERSION tags are not v-prefixed).
    tags:
      - '*'
  # Allows you to run this workflow manually from the Actions tab
  workflow_dispatch:

jobs:
  release:
    name: Create Release
    runs-on: ubuntu-latest
    permissions:
      contents: write
    steps:
      - uses: actions/checkout@v5
        with:
          # by default checkout uses a depth of 1; fetch all history so
          # the changelog script can read each commit
          fetch-depth: 0
      - name: Generate Changelog
        run: .gitea/release_message.sh > release_message.md
      - name: Release
        uses: softprops/action-gh-release@v1
        with:
          body_path: release_message.md

  deploy:
    needs: release
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v5
      - name: Check version match
        run: |
          # Refuse to deploy when the VERSION file and the pushed tag
          # disagree. exit 1 (not -1: exit status must be 0-255).
          if [ "$(cat VERSION)" = "${GITHUB_REF_NAME}" ]; then
            echo "Version matches successfully!"
          else
            echo "Version must match!"
            exit 1
          fi
      - name: Login to Gitea container registry
        uses: docker/login-action@v3
        with:
          username: gitearobot
          password: ${{ secrets.PACKAGE_GITEA_PAT }}
          registry: git.disi.dev
      - name: Build and publish
        run: |
          # Registry namespaces must be lowercase.
          REPOSITORY_OWNER=$(echo "$GITHUB_REPOSITORY" | awk -F '/' '{print $1}' | tr '[:upper:]' '[:lower:]')
          docker build -t "git.disi.dev/$REPOSITORY_OWNER/unraid-mcp:$(cat VERSION)" ./
          docker push "git.disi.dev/$REPOSITORY_OWNER/unraid-mcp:$(cat VERSION)"

View File

@@ -1,5 +1,5 @@
# Use an official Python runtime as a parent image # Use an official Python runtime as a parent image
FROM python:3.12-slim FROM python:3.11-slim
# Set the working directory in the container # Set the working directory in the container
WORKDIR /app WORKDIR /app
@@ -7,22 +7,14 @@ WORKDIR /app
# Install uv # Install uv
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /usr/local/bin/ COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /usr/local/bin/
# Create non-root user with home directory and give ownership of /app # Copy dependency files
RUN groupadd --gid 1000 appuser && \ COPY pyproject.toml .
useradd --uid 1000 --gid 1000 --create-home --shell /bin/false appuser && \ COPY uv.lock .
chown appuser:appuser /app COPY README.md .
COPY LICENSE .
# Copy dependency files (owned by appuser via --chown)
COPY --chown=appuser:appuser pyproject.toml .
COPY --chown=appuser:appuser uv.lock .
COPY --chown=appuser:appuser README.md .
COPY --chown=appuser:appuser LICENSE .
# Copy the source code # Copy the source code
COPY --chown=appuser:appuser unraid_mcp/ ./unraid_mcp/ COPY unraid_mcp/ ./unraid_mcp/
# Switch to non-root user before installing dependencies
USER appuser
# Install dependencies and the package # Install dependencies and the package
RUN uv sync --frozen RUN uv sync --frozen
@@ -40,9 +32,5 @@ ENV UNRAID_API_KEY=""
ENV UNRAID_VERIFY_SSL="true" ENV UNRAID_VERIFY_SSL="true"
ENV UNRAID_MCP_LOG_LEVEL="INFO" ENV UNRAID_MCP_LOG_LEVEL="INFO"
# Health check # Run unraid-mcp-server.py when the container launches
HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \
CMD ["python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:6970/mcp')"]
# Run unraid-mcp-server when the container launches
CMD ["uv", "run", "unraid-mcp-server"] CMD ["uv", "run", "unraid-mcp-server"]

459
HISTORY.md Normal file
View File

@@ -0,0 +1,459 @@
Changelog
=========
(unreleased)
------------
Fix
~~~
- Set scripts as executable, refs NOISSUE. [Simon Diesenreiter]
- Initial test release. [Simon Diesenreiter]
0.0.1 (2026-02-21)
------------------
Fix
~~~
- Use CLAUDE_PLUGIN_ROOT for portable MCP server configuration. [Jacob
Magar]
Update .mcp.json to use environment variable
for the --directory argument, ensuring the MCP server works correctly
regardless of where the plugin is installed.
This follows Claude Code plugin best practices for MCP server bundling.
- Correct marketplace.json source field format. [Jacob Magar]
Change source from absolute GitHub URL to relative path "./"
This follows Claude Code marketplace convention where source paths
are relative to the cloned repository root, not external URLs.
Matches pattern from working examples like claude-homelab marketplace.
- Upgrade fastmcp and mcp to resolve remaining security vulnerabilities.
[Claude]
Security Updates:
- fastmcp 2.12.5 → 2.14.5 (fixes CVE-2025-66416, command injection, XSS, auth takeover)
- mcp 1.16.0 → 1.26.0 (enables DNS rebinding protection, addresses CVE requirements)
- websockets 13.1 → 16.0 (required dependency for fastmcp 2.14.5)
Dependency Changes:
+ beartype 0.22.9
+ cachetools 7.0.1
+ cloudpickle 3.1.2
+ croniter 6.0.0
+ diskcache 5.6.3
+ fakeredis 2.34.0
+ importlib-metadata 8.7.1
+ jsonref 1.1.1
+ lupa 2.6
+ opentelemetry-api 1.39.1
+ pathvalidate 3.3.1
+ platformdirs 4.9.2
+ prometheus-client 0.24.1
+ py-key-value-aio 0.3.0
+ py-key-value-shared 0.3.0
+ pydocket 0.17.7
+ pyjwt 2.11.0
+ python-dateutil 2.9.0.post0
+ python-json-logger 4.0.0
+ redis 7.2.0
+ shellingham 1.5.4
+ sortedcontainers 2.4.0
+ typer 0.23.2
+ zipp 3.23.0
Removed Dependencies:
- isodate 0.7.2
- lazy-object-proxy 1.12.0
- markupsafe 3.0.3
- openapi-core 0.22.0
- openapi-schema-validator 0.6.3
- openapi-spec-validator 0.7.2
- rfc3339-validator 0.1.4
- werkzeug 3.1.5
Testing:
- All 493 tests pass
- Type checking passes (ty check)
- Linting passes (ruff check)
This completes the resolution of GitHub Dependabot security alerts.
Addresses the remaining 5 high/medium severity vulnerabilities in fastmcp and mcp packages.
- Correct marketplace.json source field and improve async operations.
[Jacob Magar]
- Fix marketplace.json: change source from relative path to GitHub URL
(was "skills/unraid", now "https://github.com/jmagar/unraid-mcp")
This resolves the "Invalid input" schema validation error when adding
the marketplace to Claude Code
- Refactor subscriptions autostart to use anyio.Path for async file checks
(replaces blocking pathlib.Path.exists() with async anyio.Path.exists())
- Update dependencies: anyio 4.11.0→4.12.1, attrs 25.3.0→25.4.0
- Correct marketplace.json format for Claude Code compatibility.
[Claude]
- Rename marketplace from "unraid-mcp" to "jmagar-unraid-mcp" to match expected directory structure
- Wrap description, version, homepage, and repository in metadata object per standard format
- Fixes "Marketplace file not found" error when adding marketplace to Claude Code
Resolves marketplace installation issues by aligning with format used by other Claude Code marketplaces.
- Address PR comment #38 - remove duplicate User-Agent header. [Jacob
Magar]
Resolves review thread PRRT_kwDOO6Hdxs5uu7z7
- Removed redundant User-Agent header from per-request headers in make_graphql_request()
- User-Agent is already set as default header on the shared HTTP client
- httpx merges per-request headers with client defaults, so client-level default is sufficient
- Harden read-logs.sh against GraphQL injection and path traversal.
[Jacob Magar]
- Remove slashes from LOG_NAME regex to block path traversal (e.g.
../../etc/passwd). Only alphanumeric, dots, hyphens, underscores allowed.
- Cap LINES to 1-10000 range to prevent resource exhaustion.
- Add query script existence check before execution.
- Add query failure, empty response, and invalid JSON guards.
Resolves review thread PRRT_kwDOO6Hdxs5uvKrj
- Address 5 critical and major PR review issues. [Jacob Magar]
- Remove set -e from validate-marketplace.sh to prevent early exit on
check failures, allowing the summary to always be displayed (PRRT_kwDOO6Hdxs5uvKrc)
- Fix marketplace.json source path to point to skills/unraid instead of
./ for correct plugin directory resolution (PRRT_kwDOO6Hdxs5uvKrg)
- Fix misleading trap registration comment in unraid-api-crawl.md and
add auth note to Apollo Studio URL (PRRT_kwDOO6Hdxs5uvO2t)
- Extract duplicated cleanup-with-error-handling in main.py into
_run_shutdown_cleanup() helper (PRRT_kwDOO6Hdxs5uvO3A)
- Add input validation to read-logs.sh to prevent GraphQL injection
via LOG_NAME and LINES parameters (PRRT_kwDOO6Hdxs5uvKrj)
- Address PR review comments on test suite. [Jacob Magar]
- Rename test_start_http_401_unauthorized to test_list_http_401_unauthorized
to match the actual action="list" being tested (threads #19, #23)
- Use consistent PrefixedID format ("a"*64+":local") in test_start_container
instead of "abc123def456"*4 concatenation (thread #37)
- Refactor container_actions_require_id to use @pytest.mark.parametrize
so each action runs independently (thread #18)
- Fix docstring claiming ToolError for test that asserts success in
test_stop_mutation_returns_null (thread #26)
- Fix inaccurate comment about `in` operator checking truthiness;
it checks key existence (thread #25)
- Add edge case tests for temperature=0, temperature=null, and
logFile=null in test_storage.py (thread #31)
Resolves review threads: PRRT_kwDOO6Hdxs5uvO2-, PRRT_kwDOO6Hdxs5uvOcf,
PRRT_kwDOO6Hdxs5uu7zx, PRRT_kwDOO6Hdxs5uvO28, PRRT_kwDOO6Hdxs5uvOcp,
PRRT_kwDOO6Hdxs5uvOcn, PRRT_kwDOO6Hdxs5uvKr3
- Harden shell scripts with error handling and null guards. [Jacob
Magar]
- dashboard.sh: Add // [] jq null guard on .data.array.disks[] (L176-177)
Resolves review thread PRRT_kwDOO6Hdxs5uvO21
- dashboard.sh: Default NAME to server key when env var unset (L221)
Resolves review thread PRRT_kwDOO6Hdxs5uvO22
- unraid-query.sh: Check curl exit code, empty response, and JSON validity
before piping to jq (L112-129)
Resolves review thread PRRT_kwDOO6Hdxs5uvO24
- disk-health.sh: Guard against missing query script and invalid responses
Resolves review thread PRRT_kwDOO6Hdxs5uvKrh
- Address 54 MEDIUM/LOW priority PR review issues. [Jacob Magar]
Comprehensive fixes across Python code, shell scripts, and documentation
addressing all remaining MEDIUM and LOW priority review comments.
Python Code Fixes (27 fixes):
- tools/info.py: Simplified dispatch with lookup tables, defensive guards,
CPU fallback formatting, !s conversion flags, module-level sync assertion
- tools/docker.py: Case-insensitive container ID regex, keyword-only confirm,
module-level ALL_ACTIONS constant
- tools/virtualization.py: Normalized single-VM dict responses, unified
list/details queries
- core/client.py: Fixed HTTP client singleton race condition, compound key
substring matching for sensitive data redaction
- subscriptions/: Extracted SSL context creation to shared helper in utils.py,
replaced deprecated ssl._create_unverified_context API
- tools/array.py: Renamed parity_history to parity_status, hoisted ALL_ACTIONS
- tools/storage.py: Fixed dict(None) risks, temperature 0 falsiness bug
- tools/notifications.py, keys.py, rclone.py: Fixed dict(None) TypeError risks
- tests/: Fixed generator type annotations, added coverage for compound keys
Shell Script Fixes (13 fixes):
- dashboard.sh: Dynamic server discovery, conditional debug output, null-safe
jq, notification count guard order, removed unused variables
- unraid-query.sh: Proper JSON escaping via jq, --ignore-errors and --insecure
CLI flags, TLS verification now on by default
- validate-marketplace.sh: Removed unused YELLOW variable, defensive jq,
simplified repository URL output
Documentation Fixes (24+ fixes):
- Version consistency: Updated all references to v0.2.0 across pyproject.toml,
plugin.json, marketplace.json, MARKETPLACE.md, __init__.py, README files
- Tool count updates: Changed all "26 tools" references to "10 tools, 90 actions"
- Markdown lint: Fixed MD022, MD031, MD047 issues across multiple files
- Research docs: Fixed auth headers, removed web artifacts, corrected stale info
- Skills docs: Fixed query examples, endpoint counts, env var references
All 227 tests pass, ruff and ty checks clean.
- Update Subprotocol import and SSL handling in WebSocket modules.
[Jacob Magar]
- Change Subprotocol import from deprecated websockets.legacy.protocol
to websockets.typing (canonical location in websockets 13.x)
- Fix SSL context handling in diagnostics.py to properly build
ssl.SSLContext objects, matching the pattern in manager.py
(previously passed UNRAID_VERIFY_SSL directly which breaks
when it's a CA bundle path string)
- Resolve ruff lint issues in storage tool and tests. [Jacob Magar]
- Move _ALLOWED_LOG_PREFIXES to module level (N806: constant naming)
- Use f-string conversion flag {e!s} instead of {str(e)} (RUF010)
- Fix import block sorting in both files (I001)
- Address 18 CRITICAL+HIGH PR review comments. [Jacob Magar, config-
fixer, docker-fixer, info-fixer, keys-rclone-fixer, storage-fixer,
users-fixer, vm-fixer, websocket-fixer]
**Critical Fixes (7 issues):**
- Fix GraphQL schema field names in users tool (role→roles, remove email)
- Fix GraphQL mutation signatures (addUserInput, deleteUser input)
- Fix dict(None) TypeError guards in users tool (use `or {}` pattern)
- Fix FastAPI version constraint (0.116.1→0.115.0)
- Fix WebSocket SSL context handling (support CA bundles, bool, and None)
- Fix critical disk threshold treated as warning (split counters)
**High Priority Fixes (11 issues):**
- Fix Docker update/remove action response field mapping
- Fix path traversal vulnerability in log validation (normalize paths)
- Fix deleteApiKeys validation (check response before success)
- Fix rclone create_remote validation (check response)
- Fix keys input_data type annotation (dict[str, Any])
- Fix VM domain/domains fallback restoration
**Changes by file:**
- unraid_mcp/tools/docker.py: Response field mapping
- unraid_mcp/tools/info.py: Split critical/warning counters
- unraid_mcp/tools/storage.py: Path normalization for traversal protection
- unraid_mcp/tools/users.py: GraphQL schema + null handling
- unraid_mcp/tools/keys.py: Validation + type annotations
- unraid_mcp/tools/rclone.py: Response validation
- unraid_mcp/tools/virtualization.py: Domain fallback
- unraid_mcp/subscriptions/manager.py: SSL context creation
- pyproject.toml: FastAPI version fix
- tests/*: New tests for all fixes
**Review threads resolved:**
PRRT_kwDOO6Hdxs5uu70L, PRRT_kwDOO6Hdxs5uu70O, PRRT_kwDOO6Hdxs5uu70V,
PRRT_kwDOO6Hdxs5uu70e, PRRT_kwDOO6Hdxs5uu70i, PRRT_kwDOO6Hdxs5uu7zn,
PRRT_kwDOO6Hdxs5uu7z_, PRRT_kwDOO6Hdxs5uu7sI, PRRT_kwDOO6Hdxs5uu7sJ,
PRRT_kwDOO6Hdxs5uu7sK, PRRT_kwDOO6Hdxs5uu7Tk, PRRT_kwDOO6Hdxs5uu7Tn,
PRRT_kwDOO6Hdxs5uu7Tr, PRRT_kwDOO6Hdxs5uu7Ts, PRRT_kwDOO6Hdxs5uu7Tu,
PRRT_kwDOO6Hdxs5uu7Tv, PRRT_kwDOO6Hdxs5uu7Tw, PRRT_kwDOO6Hdxs5uu7Tx
All tests passing.
- Add type annotation to resolve mypy Literal narrowing error. [Jacob
Magar]
Other
~~~~~
- Ci: add Homelab CI framework. [Simon Diesenreiter]
- Refactor: move MCP server config inline to plugin.json. [Jacob Magar]
Move MCP server configuration from standalone .mcp.json to inline
definition in plugin.json. This consolidates all plugin metadata
in a single location.
- Add type: stdio and env fields to inline config
- Remove redundant .mcp.json file
- Maintains same functionality with cleaner structure
- Feat: add MCP server configuration for Claude Code plugin integration.
[Jacob Magar]
Add .mcp.json to configure the Unraid MCP server as a stdio-based MCP
server for Claude Code plugin integration. This allows Claude Code to
automatically start and connect to the server when the plugin is loaded.
- Type: stdio (standard input/output communication)
- Command: uv run unraid-mcp-server
- Forces stdio transport mode via UNRAID_MCP_TRANSPORT env var
- Docs: fix markdown lint, broken links, stale counts, and publishing
guidance. [Jacob Magar]
- Fix broken ToC anchors in competitive-analysis.md (MD051)
- Add blank lines before code blocks in api-reference.md (MD031)
- Add language identifiers to directory tree code blocks in MARKETPLACE.md and skills/unraid/README.md (MD040)
- Fix size unit guidance conflict: clarify disk sizes are KB, memory is bytes
- Update stale "90 actions" references to "76 actions" across research docs
- Fix coverage table terminology and clarify 22% coverage calculation
- Recommend PyPI Trusted Publishing (OIDC) over API token secrets in PUBLISHING.md
- Update action count in .claude-plugin/README.md
Resolves review threads: PRRT_kwDOO6Hdxs5uvO2m, PRRT_kwDOO6Hdxs5uvO2o,
PRRT_kwDOO6Hdxs5uvO2r, PRRT_kwDOO6Hdxs5uvOcl, PRRT_kwDOO6Hdxs5uvOcr,
PRRT_kwDOO6Hdxs5uvKrq, PRRT_kwDOO6Hdxs5uvO2u, PRRT_kwDOO6Hdxs5uvO2w,
PRRT_kwDOO6Hdxs5uvO2z, PRRT_kwDOO6Hdxs5uu7zl
- Feat: enhance test suite with 275 new tests across 4 validation
categories. [Claude, Jacob Magar]
Add comprehensive test coverage beyond unit tests:
- Schema validation (93 tests): Validate all GraphQL queries/mutations against extracted Unraid API schema
- HTTP layer (88 tests): Test request construction, timeouts, and error handling at httpx level
- Subscriptions (55 tests): WebSocket lifecycle, reconnection, and protocol validation
- Safety audit (39 tests): Enforce destructive action confirmation requirements
Total test count increased from 210 to 485 (130% increase), all passing in 5.91s.
New dependencies:
- graphql-core>=3.2.0 for schema validation
- respx>=0.22.0 for HTTP layer mocking
Files created:
- docs/unraid-schema.graphql (150-type GraphQL schema)
- tests/schema/test_query_validation.py
- tests/http_layer/test_request_construction.py
- tests/integration/test_subscriptions.py
- tests/safety/test_destructive_guards.py
- Feat: harden API safety and expand command docs with full test
coverage. [Jacob Magar]
- Refactor: move plugin manifest to repository root per Claude Code best
practices. [Jacob Magar]
- Move plugin.json from skills/unraid/.claude-plugin/ to .claude-plugin/
- Update validation script to use correct plugin manifest path
- Add plugin structure section to root README.md
- Add installation instructions to skills/unraid/README.md
- Aligns with Claude Code's expectation for source: './' in marketplace.json
- Chore: enhance project metadata, tooling, and documentation. [Claude,
Jacob Magar]
**Project Configuration:**
- Enhance pyproject.toml with comprehensive metadata, keywords, and classifiers
- Add LICENSE file (MIT) for proper open-source distribution
- Add PUBLISHING.md with comprehensive publishing guidelines
- Update .gitignore to exclude tool artifacts (.cache, .pytest_cache, .ruff_cache, .ty_cache)
- Ignore documentation working directories (.docs, .full-review, docs/plans, docs/sessions)
**Documentation:**
- Add extensive Unraid API research documentation
- API source code analysis and resolver mapping
- Competitive analysis and feature gap assessment
- Release notes analysis (7.0.0, 7.1.0, 7.2.0)
- Connect platform overview and remote access documentation
- Document known API patterns, limitations, and edge cases
**Testing & Code Quality:**
- Expand test coverage across all tool modules
- Add destructive action confirmation tests
- Improve test assertions and error case validation
- Refine type annotations for better static analysis
**Tool Improvements:**
- Enhance error handling consistency across all tools
- Improve type safety with explicit type annotations
- Refine GraphQL query construction patterns
- Better handling of optional parameters and edge cases
This commit prepares the project for v0.2.0 release with improved
metadata, comprehensive documentation, and enhanced code quality.
- Feat: consolidate 26 tools into 10 tools with 90 actions. [Jacob
Magar]
Refactor the entire tool layer to use the consolidated action pattern
(action: Literal[...] with QUERIES/MUTATIONS dicts). This reduces LLM
context from ~12k to ~5k tokens while adding ~60 new API capabilities.
New tools: unraid_info (19 actions), unraid_array (12), unraid_notifications (9),
unraid_users (8), unraid_keys (5). Rewritten: unraid_docker (15), unraid_vm (9),
unraid_storage (6), unraid_rclone (4), unraid_health (3).
Includes 129 tests across 10 test files, code review fixes for 16 issues
(severity ordering, PrefixedID regex, sensitive var redaction, etc.).
Removes tools/system.py (replaced by tools/info.py). Version bumped to 0.2.0.
- Chore: update .gitignore. [Jacob Magar]
- Move pid and log files to /tmp directory. [Claude, Jacob Magar]
- Update dev.sh to use /tmp for LOG_DIR instead of PROJECT_DIR/logs
- Update settings.py to use /tmp for LOGS_DIR instead of PROJECT_ROOT/logs
- This change moves both pid files and log files to the temporary directory
🤖 Generated with [Claude Code](https://claude.ai/code)
- Remove env_file from docker-compose and use explicit environment
variables. [Claude, Jacob Magar]
- Remove env_file directive from docker-compose.yml to eliminate .env file dependency
- Add explicit environment variable declarations with default values using ${VAR:-default} syntax
- Update port mapping to use UNRAID_MCP_PORT environment variable for both host and container
- Include all 11 environment variables used by the application with proper defaults
- Update README.md Docker deployment instructions to use export commands instead of .env files
- Update manual Docker run command to use -e flags instead of --env-file
This makes Docker deployment self-contained and follows container best practices.
🤖 Generated with [Claude Code](https://claude.ai/code)
- Replace log rotation with 10MB overwrite behavior. [Claude, Jacob
Magar]
- Create OverwriteFileHandler class that caps log files at 10MB and overwrites instead of rotating
- Remove RotatingFileHandler dependency and backup file creation
- Add reset marker logging when file limit is reached for troubleshooting
- Update all logger configurations (main, FastMCP, and root loggers)
- Increase file size limit from 5MB to 10MB as requested
- Maintain existing Rich console formatting and error handling
🤖 Generated with [Claude Code](https://claude.ai/code)
- Align documentation and Docker configuration with current
implementation. [Claude, Jacob Magar]
- Fix README.md: Make Docker deployment recommended, remove duplicate installation section
- Fix Dockerfile: Copy correct source files (unraid_mcp/, uv.lock, README.md) instead of non-existent unraid_mcp_server.py
- Update docker-compose.yml: Enable build configuration and use .env instead of .env.local
- Add missing environment variables to .env.example and .env: UNRAID_AUTO_START_SUBSCRIPTIONS, UNRAID_MAX_RECONNECT_ATTEMPTS
- Fix CLAUDE.md: Correct environment hierarchy documentation (../env.local → ../.env.local)
- Remove unused unraid-schema.json file
🤖 Generated with [Claude Code](https://claude.ai/code)
- Lintfree. [Jacob Magar]
- Add Claude Code agent configuration and GraphQL introspection.
[Claude, Jacob Magar]
- Added KFC (Kent Feature Creator) spec workflow agents for requirements, design, tasks, testing, implementation and evaluation
- Added Claude Code settings configuration for agent workflows
- Added GraphQL introspection query and schema files for Unraid API exploration
- Updated development script with additional debugging and schema inspection capabilities
- Enhanced logging configuration with structured formatting
- Updated pyproject.toml dependencies and uv.lock
🤖 Generated with [Claude Code](https://claude.ai/code)
- Remove unused MCP resources and update documentation. [Claude, Jacob
Magar]
- Remove array_status, system_info, notifications_overview, and parity_status resources
- Keep only logs_stream resource (unraid://logs/stream) which is working properly
- Update README.md with current resource documentation and modern docker compose syntax
- Fix import path issues that were causing subscription errors
- Update environment configuration examples
- Clean up subscription manager to only include working log streaming
🤖 Generated with [Claude Code](https://claude.ai/code)
- Migrate to uv and FastMCP architecture with comprehensive tooling.
[Claude, Jacob Magar]
- Replace pip/requirements.txt with uv and pyproject.toml
- Restructure as single-file MCP server using FastMCP
- Add comprehensive Unraid management tools (containers, VMs, storage, logs)
- Implement multiple transport support (streamable-http, SSE, stdio)
- Add robust error handling and timeout management
- Include project documentation and API feature tracking
- Remove outdated cline documentation structure
🤖 Generated with [Claude Code](https://claude.ai/code)
- Update docker-compose.yml. [Jacob Magar]

26
Makefile Normal file
View File

@@ -0,0 +1,26 @@
.ONESHELL:
VERSION ?= $(shell cat ./VERSION)

# Guard target: release/build depend on it. Set SKIP_MAKE_SETUP_CHECK=true
# (as CI does) to bypass the hook check.
# NOTE: the env var must be written as $$SKIP_... — a single $ would be
# expanded by make itself (as make variable $S), never by the shell.
.PHONY: issetup
issetup:
	@[ -f .git/hooks/commit-msg ] || [ "$${SKIP_MAKE_SETUP_CHECK:-}" = "true" ] || (echo "You must run 'make setup' first to initialize the repo!" && exit 1)

# Install the conventional-commit message hook into the local repo.
.PHONY: setup
setup:
	@cp .gitea/conventional_commits/commit-msg .git/hooks/

.PHONY: help
help: ## Show the help.
	@echo "Usage: make <target>"
	@echo ""
	@echo "Targets:"
	@fgrep "##" Makefile | fgrep -v fgrep

.PHONY: release
release: issetup ## Create a new tag for release.
	@./.gitea/conventional_commits/generate-version.sh

# Build the Docker image tagged with the current VERSION file contents.
.PHONY: build
build: issetup
	@docker build -t unraid-mcp:${VERSION} .

3
NOTICE.md Normal file
View File

@@ -0,0 +1,3 @@
# Notice
This is a fork of an externally maintained repository, intended only for internal use in HomeLab.

1
VERSION Normal file
View File

@@ -0,0 +1 @@
0.0.2

View File

@@ -5,11 +5,6 @@ services:
dockerfile: Dockerfile dockerfile: Dockerfile
container_name: unraid-mcp container_name: unraid-mcp
restart: unless-stopped restart: unless-stopped
read_only: true
cap_drop:
- ALL
tmpfs:
- /tmp:noexec,nosuid,size=64m
ports: ports:
# HostPort:ContainerPort (maps to UNRAID_MCP_PORT inside the container, default 6970) # HostPort:ContainerPort (maps to UNRAID_MCP_PORT inside the container, default 6970)
# Change the host port (left side) if 6970 is already in use on your host # Change the host port (left side) if 6970 is already in use on your host

View File

@@ -77,6 +77,7 @@ dependencies = [
"uvicorn[standard]>=0.35.0", "uvicorn[standard]>=0.35.0",
"websockets>=15.0.1", "websockets>=15.0.1",
"rich>=14.1.0", "rich>=14.1.0",
"pytz>=2025.2",
] ]
# ============================================================================ # ============================================================================
@@ -169,8 +170,6 @@ select = [
"PERF", "PERF",
# Ruff-specific rules # Ruff-specific rules
"RUF", "RUF",
# flake8-bandit (security)
"S",
] ]
ignore = [ ignore = [
"E501", # line too long (handled by ruff formatter) "E501", # line too long (handled by ruff formatter)
@@ -286,6 +285,7 @@ dev = [
"pytest-asyncio>=1.2.0", "pytest-asyncio>=1.2.0",
"pytest-cov>=7.0.0", "pytest-cov>=7.0.0",
"respx>=0.22.0", "respx>=0.22.0",
"types-pytz>=2025.2.0.20250809",
"ty>=0.0.15", "ty>=0.0.15",
"ruff>=0.12.8", "ruff>=0.12.8",
"build>=1.2.2", "build>=1.2.2",

View File

@@ -19,7 +19,6 @@ from tests.conftest import make_tool_fn
from unraid_mcp.core.client import DEFAULT_TIMEOUT, DISK_TIMEOUT, make_graphql_request from unraid_mcp.core.client import DEFAULT_TIMEOUT, DISK_TIMEOUT, make_graphql_request
from unraid_mcp.core.exceptions import ToolError from unraid_mcp.core.exceptions import ToolError
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Shared fixtures # Shared fixtures
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -159,43 +158,43 @@ class TestHttpErrorHandling:
@respx.mock @respx.mock
async def test_http_401_raises_tool_error(self) -> None: async def test_http_401_raises_tool_error(self) -> None:
respx.post(API_URL).mock(return_value=httpx.Response(401, text="Unauthorized")) respx.post(API_URL).mock(return_value=httpx.Response(401, text="Unauthorized"))
with pytest.raises(ToolError, match="Unraid API returned HTTP 401"): with pytest.raises(ToolError, match="HTTP error 401"):
await make_graphql_request("query { online }") await make_graphql_request("query { online }")
@respx.mock @respx.mock
async def test_http_403_raises_tool_error(self) -> None: async def test_http_403_raises_tool_error(self) -> None:
respx.post(API_URL).mock(return_value=httpx.Response(403, text="Forbidden")) respx.post(API_URL).mock(return_value=httpx.Response(403, text="Forbidden"))
with pytest.raises(ToolError, match="Unraid API returned HTTP 403"): with pytest.raises(ToolError, match="HTTP error 403"):
await make_graphql_request("query { online }") await make_graphql_request("query { online }")
@respx.mock @respx.mock
async def test_http_500_raises_tool_error(self) -> None: async def test_http_500_raises_tool_error(self) -> None:
respx.post(API_URL).mock(return_value=httpx.Response(500, text="Internal Server Error")) respx.post(API_URL).mock(return_value=httpx.Response(500, text="Internal Server Error"))
with pytest.raises(ToolError, match="Unraid API returned HTTP 500"): with pytest.raises(ToolError, match="HTTP error 500"):
await make_graphql_request("query { online }") await make_graphql_request("query { online }")
@respx.mock @respx.mock
async def test_http_503_raises_tool_error(self) -> None: async def test_http_503_raises_tool_error(self) -> None:
respx.post(API_URL).mock(return_value=httpx.Response(503, text="Service Unavailable")) respx.post(API_URL).mock(return_value=httpx.Response(503, text="Service Unavailable"))
with pytest.raises(ToolError, match="Unraid API returned HTTP 503"): with pytest.raises(ToolError, match="HTTP error 503"):
await make_graphql_request("query { online }") await make_graphql_request("query { online }")
@respx.mock @respx.mock
async def test_network_connection_error(self) -> None: async def test_network_connection_error(self) -> None:
respx.post(API_URL).mock(side_effect=httpx.ConnectError("Connection refused")) respx.post(API_URL).mock(side_effect=httpx.ConnectError("Connection refused"))
with pytest.raises(ToolError, match="Network error connecting to Unraid API"): with pytest.raises(ToolError, match="Network connection error"):
await make_graphql_request("query { online }") await make_graphql_request("query { online }")
@respx.mock @respx.mock
async def test_network_timeout_error(self) -> None: async def test_network_timeout_error(self) -> None:
respx.post(API_URL).mock(side_effect=httpx.ReadTimeout("Read timed out")) respx.post(API_URL).mock(side_effect=httpx.ReadTimeout("Read timed out"))
with pytest.raises(ToolError, match="Network error connecting to Unraid API"): with pytest.raises(ToolError, match="Network connection error"):
await make_graphql_request("query { online }") await make_graphql_request("query { online }")
@respx.mock @respx.mock
async def test_invalid_json_response(self) -> None: async def test_invalid_json_response(self) -> None:
respx.post(API_URL).mock(return_value=httpx.Response(200, text="not json")) respx.post(API_URL).mock(return_value=httpx.Response(200, text="not json"))
with pytest.raises(ToolError, match=r"invalid response.*not valid JSON"): with pytest.raises(ToolError, match="Invalid JSON response"):
await make_graphql_request("query { online }") await make_graphql_request("query { online }")
@@ -583,7 +582,7 @@ class TestVMToolRequests:
return_value=_graphql_response({"vm": {"stop": True}}) return_value=_graphql_response({"vm": {"stop": True}})
) )
tool = self._get_tool() tool = self._get_tool()
await tool(action="stop", vm_id="vm-456") result = await tool(action="stop", vm_id="vm-456")
body = _extract_request_body(route.calls.last.request) body = _extract_request_body(route.calls.last.request)
assert "StopVM" in body["query"] assert "StopVM" in body["query"]
assert body["variables"] == {"id": "vm-456"} assert body["variables"] == {"id": "vm-456"}
@@ -869,14 +868,14 @@ class TestNotificationsToolRequests:
title="Test", title="Test",
subject="Sub", subject="Sub",
description="Desc", description="Desc",
importance="normal", importance="info",
) )
body = _extract_request_body(route.calls.last.request) body = _extract_request_body(route.calls.last.request)
assert "CreateNotification" in body["query"] assert "CreateNotification" in body["query"]
inp = body["variables"]["input"] inp = body["variables"]["input"]
assert inp["title"] == "Test" assert inp["title"] == "Test"
assert inp["subject"] == "Sub" assert inp["subject"] == "Sub"
assert inp["importance"] == "NORMAL" # uppercased from "normal" assert inp["importance"] == "INFO" # uppercased
@respx.mock @respx.mock
async def test_archive_sends_id_variable(self) -> None: async def test_archive_sends_id_variable(self) -> None:
@@ -1257,7 +1256,7 @@ class TestCrossCuttingConcerns:
tool = make_tool_fn( tool = make_tool_fn(
"unraid_mcp.tools.info", "register_info_tool", "unraid_info" "unraid_mcp.tools.info", "register_info_tool", "unraid_info"
) )
with pytest.raises(ToolError, match="Unraid API returned HTTP 500"): with pytest.raises(ToolError, match="HTTP error 500"):
await tool(action="online") await tool(action="online")
@respx.mock @respx.mock
@@ -1269,7 +1268,7 @@ class TestCrossCuttingConcerns:
tool = make_tool_fn( tool = make_tool_fn(
"unraid_mcp.tools.info", "register_info_tool", "unraid_info" "unraid_mcp.tools.info", "register_info_tool", "unraid_info"
) )
with pytest.raises(ToolError, match="Network error connecting to Unraid API"): with pytest.raises(ToolError, match="Network connection error"):
await tool(action="online") await tool(action="online")
@respx.mock @respx.mock

View File

@@ -7,7 +7,7 @@ data management without requiring a live Unraid server.
import asyncio import asyncio
import json import json
from datetime import UTC, datetime from datetime import datetime
from typing import Any from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
@@ -16,7 +16,6 @@ import websockets.exceptions
from unraid_mcp.subscriptions.manager import SubscriptionManager from unraid_mcp.subscriptions.manager import SubscriptionManager
pytestmark = pytest.mark.integration pytestmark = pytest.mark.integration
@@ -84,7 +83,7 @@ SAMPLE_QUERY = "subscription { test { value } }"
# Shared patch targets # Shared patch targets
_WS_CONNECT = "unraid_mcp.subscriptions.manager.websockets.connect" _WS_CONNECT = "unraid_mcp.subscriptions.manager.websockets.connect"
_API_URL = "unraid_mcp.subscriptions.utils.UNRAID_API_URL" _API_URL = "unraid_mcp.subscriptions.manager.UNRAID_API_URL"
_API_KEY = "unraid_mcp.subscriptions.manager.UNRAID_API_KEY" _API_KEY = "unraid_mcp.subscriptions.manager.UNRAID_API_KEY"
_SSL_CTX = "unraid_mcp.subscriptions.manager.build_ws_ssl_context" _SSL_CTX = "unraid_mcp.subscriptions.manager.build_ws_ssl_context"
_SLEEP = "unraid_mcp.subscriptions.manager.asyncio.sleep" _SLEEP = "unraid_mcp.subscriptions.manager.asyncio.sleep"
@@ -101,7 +100,7 @@ class TestSubscriptionManagerInit:
mgr = SubscriptionManager() mgr = SubscriptionManager()
assert mgr.active_subscriptions == {} assert mgr.active_subscriptions == {}
assert mgr.resource_data == {} assert mgr.resource_data == {}
assert not hasattr(mgr, "websocket") assert mgr.websocket is None
def test_default_auto_start_enabled(self) -> None: def test_default_auto_start_enabled(self) -> None:
mgr = SubscriptionManager() mgr = SubscriptionManager()
@@ -721,20 +720,20 @@ class TestWebSocketURLConstruction:
class TestResourceData: class TestResourceData:
async def test_get_resource_data_returns_none_when_empty(self) -> None: def test_get_resource_data_returns_none_when_empty(self) -> None:
mgr = SubscriptionManager() mgr = SubscriptionManager()
assert await mgr.get_resource_data("nonexistent") is None assert mgr.get_resource_data("nonexistent") is None
async def test_get_resource_data_returns_stored_data(self) -> None: def test_get_resource_data_returns_stored_data(self) -> None:
from unraid_mcp.core.types import SubscriptionData from unraid_mcp.core.types import SubscriptionData
mgr = SubscriptionManager() mgr = SubscriptionManager()
mgr.resource_data["test"] = SubscriptionData( mgr.resource_data["test"] = SubscriptionData(
data={"key": "value"}, data={"key": "value"},
last_updated=datetime.now(UTC), last_updated=datetime.now(),
subscription_type="test", subscription_type="test",
) )
result = await mgr.get_resource_data("test") result = mgr.get_resource_data("test")
assert result == {"key": "value"} assert result == {"key": "value"}
def test_list_active_subscriptions_empty(self) -> None: def test_list_active_subscriptions_empty(self) -> None:
@@ -756,46 +755,46 @@ class TestResourceData:
class TestSubscriptionStatus: class TestSubscriptionStatus:
async def test_status_includes_all_configured_subscriptions(self) -> None: def test_status_includes_all_configured_subscriptions(self) -> None:
mgr = SubscriptionManager() mgr = SubscriptionManager()
status = await mgr.get_subscription_status() status = mgr.get_subscription_status()
for name in mgr.subscription_configs: for name in mgr.subscription_configs:
assert name in status assert name in status
async def test_status_default_connection_state(self) -> None: def test_status_default_connection_state(self) -> None:
mgr = SubscriptionManager() mgr = SubscriptionManager()
status = await mgr.get_subscription_status() status = mgr.get_subscription_status()
for sub_status in status.values(): for sub_status in status.values():
assert sub_status["runtime"]["connection_state"] == "not_started" assert sub_status["runtime"]["connection_state"] == "not_started"
async def test_status_shows_active_flag(self) -> None: def test_status_shows_active_flag(self) -> None:
mgr = SubscriptionManager() mgr = SubscriptionManager()
mgr.active_subscriptions["logFileSubscription"] = MagicMock() mgr.active_subscriptions["logFileSubscription"] = MagicMock()
status = await mgr.get_subscription_status() status = mgr.get_subscription_status()
assert status["logFileSubscription"]["runtime"]["active"] is True assert status["logFileSubscription"]["runtime"]["active"] is True
async def test_status_shows_data_availability(self) -> None: def test_status_shows_data_availability(self) -> None:
from unraid_mcp.core.types import SubscriptionData from unraid_mcp.core.types import SubscriptionData
mgr = SubscriptionManager() mgr = SubscriptionManager()
mgr.resource_data["logFileSubscription"] = SubscriptionData( mgr.resource_data["logFileSubscription"] = SubscriptionData(
data={"log": "content"}, data={"log": "content"},
last_updated=datetime.now(UTC), last_updated=datetime.now(),
subscription_type="logFileSubscription", subscription_type="logFileSubscription",
) )
status = await mgr.get_subscription_status() status = mgr.get_subscription_status()
assert status["logFileSubscription"]["data"]["available"] is True assert status["logFileSubscription"]["data"]["available"] is True
async def test_status_shows_error_info(self) -> None: def test_status_shows_error_info(self) -> None:
mgr = SubscriptionManager() mgr = SubscriptionManager()
mgr.last_error["logFileSubscription"] = "Test error message" mgr.last_error["logFileSubscription"] = "Test error message"
status = await mgr.get_subscription_status() status = mgr.get_subscription_status()
assert status["logFileSubscription"]["runtime"]["last_error"] == "Test error message" assert status["logFileSubscription"]["runtime"]["last_error"] == "Test error message"
async def test_status_reconnect_attempts_tracked(self) -> None: def test_status_reconnect_attempts_tracked(self) -> None:
mgr = SubscriptionManager() mgr = SubscriptionManager()
mgr.reconnect_attempts["logFileSubscription"] = 3 mgr.reconnect_attempts["logFileSubscription"] = 3
status = await mgr.get_subscription_status() status = mgr.get_subscription_status()
assert status["logFileSubscription"]["runtime"]["reconnect_attempts"] == 3 assert status["logFileSubscription"]["runtime"]["reconnect_attempts"] == 3

View File

@@ -10,12 +10,6 @@ from unittest.mock import AsyncMock, patch
import pytest import pytest
# conftest.py is the shared test-helper module for this project.
# pytest automatically adds tests/ to sys.path, making it importable here
# without a package __init__.py. Do NOT add tests/__init__.py — it breaks
# conftest.py's fixture auto-discovery.
from conftest import make_tool_fn
from unraid_mcp.core.exceptions import ToolError from unraid_mcp.core.exceptions import ToolError
# Import DESTRUCTIVE_ACTIONS sets from every tool module that defines one # Import DESTRUCTIVE_ACTIONS sets from every tool module that defines one
@@ -30,6 +24,10 @@ from unraid_mcp.tools.rclone import MUTATIONS as RCLONE_MUTATIONS
from unraid_mcp.tools.virtualization import DESTRUCTIVE_ACTIONS as VM_DESTRUCTIVE from unraid_mcp.tools.virtualization import DESTRUCTIVE_ACTIONS as VM_DESTRUCTIVE
from unraid_mcp.tools.virtualization import MUTATIONS as VM_MUTATIONS from unraid_mcp.tools.virtualization import MUTATIONS as VM_MUTATIONS
# Centralized import for make_tool_fn helper
# conftest.py sits in tests/ and is importable without __init__.py
from conftest import make_tool_fn
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Known destructive actions registry (ground truth for this audit) # Known destructive actions registry (ground truth for this audit)
@@ -41,7 +39,7 @@ KNOWN_DESTRUCTIVE: dict[str, dict[str, set[str]]] = {
"module": "unraid_mcp.tools.docker", "module": "unraid_mcp.tools.docker",
"register_fn": "register_docker_tool", "register_fn": "register_docker_tool",
"tool_name": "unraid_docker", "tool_name": "unraid_docker",
"actions": {"remove", "update_all"}, "actions": {"remove"},
"runtime_set": DOCKER_DESTRUCTIVE, "runtime_set": DOCKER_DESTRUCTIVE,
}, },
"vm": { "vm": {
@@ -128,11 +126,9 @@ class TestDestructiveActionRegistries:
missing: list[str] = [] missing: list[str] = []
for tool_key, mutations in all_mutations.items(): for tool_key, mutations in all_mutations.items():
destructive = all_destructive[tool_key] destructive = all_destructive[tool_key]
missing.extend( for action_name in mutations:
f"{tool_key}/{action_name}" if ("delete" in action_name or "remove" in action_name) and action_name not in destructive:
for action_name in mutations missing.append(f"{tool_key}/{action_name}")
if ("delete" in action_name or "remove" in action_name) and action_name not in destructive
)
assert not missing, ( assert not missing, (
f"Mutations with 'delete'/'remove' not in DESTRUCTIVE_ACTIONS: {missing}" f"Mutations with 'delete'/'remove' not in DESTRUCTIVE_ACTIONS: {missing}"
) )
@@ -147,7 +143,6 @@ class TestDestructiveActionRegistries:
_DESTRUCTIVE_TEST_CASES: list[tuple[str, str, dict]] = [ _DESTRUCTIVE_TEST_CASES: list[tuple[str, str, dict]] = [
# Docker # Docker
("docker", "remove", {"container_id": "abc123"}), ("docker", "remove", {"container_id": "abc123"}),
("docker", "update_all", {}),
# VM # VM
("vm", "force_stop", {"vm_id": "test-vm-uuid"}), ("vm", "force_stop", {"vm_id": "test-vm-uuid"}),
("vm", "reset", {"vm_id": "test-vm-uuid"}), ("vm", "reset", {"vm_id": "test-vm-uuid"}),
@@ -273,15 +268,6 @@ class TestConfirmationGuards:
class TestConfirmAllowsExecution: class TestConfirmAllowsExecution:
"""Destructive actions with confirm=True should reach the GraphQL layer.""" """Destructive actions with confirm=True should reach the GraphQL layer."""
async def test_docker_update_all_with_confirm(self, _mock_docker_graphql: AsyncMock) -> None:
_mock_docker_graphql.return_value = {
"docker": {"updateAllContainers": [{"id": "c1", "names": ["app"], "state": "running", "status": "Up"}]}
}
tool_fn = make_tool_fn("unraid_mcp.tools.docker", "register_docker_tool", "unraid_docker")
result = await tool_fn(action="update_all", confirm=True)
assert result["success"] is True
assert result["action"] == "update_all"
async def test_docker_remove_with_confirm(self, _mock_docker_graphql: AsyncMock) -> None: async def test_docker_remove_with_confirm(self, _mock_docker_graphql: AsyncMock) -> None:
cid = "a" * 64 + ":local" cid = "a" * 64 + ":local"
_mock_docker_graphql.side_effect = [ _mock_docker_graphql.side_effect = [

View File

@@ -384,16 +384,10 @@ class TestVmQueries:
errors = _validate_operation(schema, QUERIES["list"]) errors = _validate_operation(schema, QUERIES["list"])
assert not errors, f"list query validation failed: {errors}" assert not errors, f"list query validation failed: {errors}"
def test_details_query(self, schema: GraphQLSchema) -> None:
from unraid_mcp.tools.virtualization import QUERIES
errors = _validate_operation(schema, QUERIES["details"])
assert not errors, f"details query validation failed: {errors}"
def test_all_vm_queries_covered(self, schema: GraphQLSchema) -> None: def test_all_vm_queries_covered(self, schema: GraphQLSchema) -> None:
from unraid_mcp.tools.virtualization import QUERIES from unraid_mcp.tools.virtualization import QUERIES
assert set(QUERIES.keys()) == {"list", "details"} assert set(QUERIES.keys()) == {"list"}
class TestVmMutations: class TestVmMutations:

View File

@@ -84,7 +84,7 @@ class TestArrayActions:
async def test_generic_exception_wraps(self, _mock_graphql: AsyncMock) -> None: async def test_generic_exception_wraps(self, _mock_graphql: AsyncMock) -> None:
_mock_graphql.side_effect = RuntimeError("disk error") _mock_graphql.side_effect = RuntimeError("disk error")
tool_fn = _make_tool() tool_fn = _make_tool()
with pytest.raises(ToolError, match="Failed to execute array/parity_status"): with pytest.raises(ToolError, match="disk error"):
await tool_fn(action="parity_status") await tool_fn(action="parity_status")

View File

@@ -1,7 +1,6 @@
"""Tests for unraid_mcp.core.client — GraphQL client infrastructure.""" """Tests for unraid_mcp.core.client — GraphQL client infrastructure."""
import json import json
import time
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import httpx import httpx
@@ -10,11 +9,9 @@ import pytest
from unraid_mcp.core.client import ( from unraid_mcp.core.client import (
DEFAULT_TIMEOUT, DEFAULT_TIMEOUT,
DISK_TIMEOUT, DISK_TIMEOUT,
_QueryCache, _redact_sensitive,
_RateLimiter,
is_idempotent_error, is_idempotent_error,
make_graphql_request, make_graphql_request,
redact_sensitive,
) )
from unraid_mcp.core.exceptions import ToolError from unraid_mcp.core.exceptions import ToolError
@@ -60,7 +57,7 @@ class TestIsIdempotentError:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# redact_sensitive # _redact_sensitive
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -69,36 +66,36 @@ class TestRedactSensitive:
def test_flat_dict(self) -> None: def test_flat_dict(self) -> None:
data = {"username": "admin", "password": "hunter2", "host": "10.0.0.1"} data = {"username": "admin", "password": "hunter2", "host": "10.0.0.1"}
result = redact_sensitive(data) result = _redact_sensitive(data)
assert result["username"] == "admin" assert result["username"] == "admin"
assert result["password"] == "***" assert result["password"] == "***"
assert result["host"] == "10.0.0.1" assert result["host"] == "10.0.0.1"
def test_nested_dict(self) -> None: def test_nested_dict(self) -> None:
data = {"config": {"apiKey": "abc123", "url": "http://host"}} data = {"config": {"apiKey": "abc123", "url": "http://host"}}
result = redact_sensitive(data) result = _redact_sensitive(data)
assert result["config"]["apiKey"] == "***" assert result["config"]["apiKey"] == "***"
assert result["config"]["url"] == "http://host" assert result["config"]["url"] == "http://host"
def test_list_of_dicts(self) -> None: def test_list_of_dicts(self) -> None:
data = [{"token": "t1"}, {"name": "safe"}] data = [{"token": "t1"}, {"name": "safe"}]
result = redact_sensitive(data) result = _redact_sensitive(data)
assert result[0]["token"] == "***" assert result[0]["token"] == "***"
assert result[1]["name"] == "safe" assert result[1]["name"] == "safe"
def test_deeply_nested(self) -> None: def test_deeply_nested(self) -> None:
data = {"a": {"b": {"c": {"secret": "deep"}}}} data = {"a": {"b": {"c": {"secret": "deep"}}}}
result = redact_sensitive(data) result = _redact_sensitive(data)
assert result["a"]["b"]["c"]["secret"] == "***" assert result["a"]["b"]["c"]["secret"] == "***"
def test_non_dict_passthrough(self) -> None: def test_non_dict_passthrough(self) -> None:
assert redact_sensitive("plain_string") == "plain_string" assert _redact_sensitive("plain_string") == "plain_string"
assert redact_sensitive(42) == 42 assert _redact_sensitive(42) == 42
assert redact_sensitive(None) is None assert _redact_sensitive(None) is None
def test_case_insensitive_keys(self) -> None: def test_case_insensitive_keys(self) -> None:
data = {"Password": "p1", "TOKEN": "t1", "ApiKey": "k1", "Secret": "s1", "Key": "x1"} data = {"Password": "p1", "TOKEN": "t1", "ApiKey": "k1", "Secret": "s1", "Key": "x1"}
result = redact_sensitive(data) result = _redact_sensitive(data)
for v in result.values(): for v in result.values():
assert v == "***" assert v == "***"
@@ -112,7 +109,7 @@ class TestRedactSensitive:
"username": "safe", "username": "safe",
"host": "safe", "host": "safe",
} }
result = redact_sensitive(data) result = _redact_sensitive(data)
assert result["user_password"] == "***" assert result["user_password"] == "***"
assert result["api_key_value"] == "***" assert result["api_key_value"] == "***"
assert result["auth_token_expiry"] == "***" assert result["auth_token_expiry"] == "***"
@@ -122,26 +119,12 @@ class TestRedactSensitive:
def test_mixed_list_content(self) -> None: def test_mixed_list_content(self) -> None:
data = [{"key": "val"}, "string", 123, [{"token": "inner"}]] data = [{"key": "val"}, "string", 123, [{"token": "inner"}]]
result = redact_sensitive(data) result = _redact_sensitive(data)
assert result[0]["key"] == "***" assert result[0]["key"] == "***"
assert result[1] == "string" assert result[1] == "string"
assert result[2] == 123 assert result[2] == 123
assert result[3][0]["token"] == "***" assert result[3][0]["token"] == "***"
def test_new_sensitive_keys_are_redacted(self) -> None:
"""PR-added keys: authorization, cookie, session, credential, passphrase, jwt."""
data = {
"authorization": "Bearer token123",
"cookie": "session=abc",
"jwt": "eyJ...",
"credential": "secret_cred",
"passphrase": "hunter2",
"session": "sess_id",
}
result = redact_sensitive(data)
for key, val in result.items():
assert val == "***", f"Key '{key}' was not redacted"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Timeout constants # Timeout constants
@@ -291,7 +274,7 @@ class TestMakeGraphQLRequestErrors:
with ( with (
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client), patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
pytest.raises(ToolError, match="Unraid API returned HTTP 401"), pytest.raises(ToolError, match="HTTP error 401"),
): ):
await make_graphql_request("{ info }") await make_graphql_request("{ info }")
@@ -309,7 +292,7 @@ class TestMakeGraphQLRequestErrors:
with ( with (
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client), patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
pytest.raises(ToolError, match="Unraid API returned HTTP 500"), pytest.raises(ToolError, match="HTTP error 500"),
): ):
await make_graphql_request("{ info }") await make_graphql_request("{ info }")
@@ -327,7 +310,7 @@ class TestMakeGraphQLRequestErrors:
with ( with (
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client), patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
pytest.raises(ToolError, match="Unraid API returned HTTP 503"), pytest.raises(ToolError, match="HTTP error 503"),
): ):
await make_graphql_request("{ info }") await make_graphql_request("{ info }")
@@ -337,7 +320,7 @@ class TestMakeGraphQLRequestErrors:
with ( with (
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client), patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
pytest.raises(ToolError, match="Network error connecting to Unraid API"), pytest.raises(ToolError, match="Network connection error"),
): ):
await make_graphql_request("{ info }") await make_graphql_request("{ info }")
@@ -347,7 +330,7 @@ class TestMakeGraphQLRequestErrors:
with ( with (
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client), patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
pytest.raises(ToolError, match="Network error connecting to Unraid API"), pytest.raises(ToolError, match="Network connection error"),
): ):
await make_graphql_request("{ info }") await make_graphql_request("{ info }")
@@ -361,7 +344,7 @@ class TestMakeGraphQLRequestErrors:
with ( with (
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client), patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
pytest.raises(ToolError, match=r"invalid response.*not valid JSON"), pytest.raises(ToolError, match="Invalid JSON response"),
): ):
await make_graphql_request("{ info }") await make_graphql_request("{ info }")
@@ -481,240 +464,3 @@ class TestGraphQLErrorHandling:
pytest.raises(ToolError, match="GraphQL API error"), pytest.raises(ToolError, match="GraphQL API error"),
): ):
await make_graphql_request("{ info }") await make_graphql_request("{ info }")
# ---------------------------------------------------------------------------
# _RateLimiter
# ---------------------------------------------------------------------------
class TestRateLimiter:
"""Unit tests for the token bucket rate limiter."""
async def test_acquire_consumes_one_token(self) -> None:
limiter = _RateLimiter(max_tokens=10, refill_rate=1.0)
initial = limiter.tokens
await limiter.acquire()
assert limiter.tokens == pytest.approx(initial - 1, abs=1e-3)
async def test_acquire_succeeds_when_tokens_available(self) -> None:
limiter = _RateLimiter(max_tokens=5, refill_rate=1.0)
# Should complete without sleeping
for _ in range(5):
await limiter.acquire()
# _refill() runs during each acquire() call and adds a tiny time-based
# amount; check < 1.0 (not enough for another immediate request) rather
# than == 0.0 to avoid flakiness from timing.
assert limiter.tokens < 1.0
async def test_tokens_do_not_exceed_max(self) -> None:
limiter = _RateLimiter(max_tokens=10, refill_rate=1.0)
# Force refill with large elapsed time
limiter.last_refill = time.monotonic() - 100.0 # 100 seconds ago
limiter._refill()
assert limiter.tokens == 10.0 # Capped at max_tokens
async def test_refill_adds_tokens_based_on_elapsed(self) -> None:
limiter = _RateLimiter(max_tokens=100, refill_rate=10.0)
limiter.tokens = 0.0
limiter.last_refill = time.monotonic() - 1.0 # 1 second ago
limiter._refill()
# Should have refilled ~10 tokens (10.0 rate * 1.0 sec)
assert 9.5 < limiter.tokens < 10.5
async def test_acquire_sleeps_when_no_tokens(self) -> None:
"""When tokens are exhausted, acquire should sleep before consuming."""
limiter = _RateLimiter(max_tokens=1, refill_rate=1.0)
limiter.tokens = 0.0
sleep_calls = []
async def fake_sleep(duration: float) -> None:
sleep_calls.append(duration)
# Simulate refill by advancing last_refill so tokens replenish
limiter.tokens = 1.0
limiter.last_refill = time.monotonic()
with patch("unraid_mcp.core.client.asyncio.sleep", side_effect=fake_sleep):
await limiter.acquire()
assert len(sleep_calls) == 1
assert sleep_calls[0] > 0
async def test_default_params_match_api_limits(self) -> None:
"""Default rate limiter must use 90 tokens at 9.0/sec (10% headroom from 100/10s)."""
limiter = _RateLimiter()
assert limiter.max_tokens == 90
assert limiter.refill_rate == 9.0
# ---------------------------------------------------------------------------
# _QueryCache
# ---------------------------------------------------------------------------
class TestQueryCache:
    """Unit tests for the TTL query cache."""

    def test_miss_on_empty_cache(self) -> None:
        # A freshly constructed cache holds nothing.
        cache = _QueryCache()
        assert cache.get("{ info }", None) is None

    def test_put_and_get_hit(self) -> None:
        cache = _QueryCache()
        payload = {"result": "ok"}
        cache.put("GetNetworkConfig { }", None, payload)
        assert cache.get("GetNetworkConfig { }", None) == payload

    def test_expired_entry_returns_none(self) -> None:
        cache = _QueryCache()
        payload = {"result": "ok"}
        cache.put("GetNetworkConfig { }", None, payload)
        # Backdate the stored timestamp so the entry is already stale.
        key = cache._cache_key("GetNetworkConfig { }", None)
        cache._store[key] = (time.monotonic() - 1.0, payload)  # expired 1 sec ago
        assert cache.get("GetNetworkConfig { }", None) is None

    def test_invalidate_all_clears_store(self) -> None:
        cache = _QueryCache()
        cache.put("GetNetworkConfig { }", None, {"x": 1})
        cache.put("GetOwner { }", None, {"y": 2})
        assert len(cache._store) == 2
        cache.invalidate_all()
        assert len(cache._store) == 0

    def test_variables_affect_cache_key(self) -> None:
        """Different variables produce different cache keys."""
        cache = _QueryCache()
        query = "GetNetworkConfig($id: ID!) { network(id: $id) { name } }"
        cache.put(query, {"id": "1"}, {"name": "eth0"})
        cache.put(query, {"id": "2"}, {"name": "eth1"})
        # Each variable set must resolve to its own cached value.
        assert cache.get(query, {"id": "1"}) == {"name": "eth0"}
        assert cache.get(query, {"id": "2"}) == {"name": "eth1"}

    def test_is_cacheable_returns_true_for_known_prefixes(self) -> None:
        # Every whitelisted operation name must be accepted.
        for operation in (
            "GetNetworkConfig { ... }",
            "GetRegistrationInfo { ... }",
            "GetOwner { ... }",
            "GetFlash { ... }",
        ):
            assert _QueryCache.is_cacheable(operation) is True

    def test_is_cacheable_returns_false_for_mutations(self) -> None:
        assert _QueryCache.is_cacheable('mutation { docker { start(id: "x") } }') is False

    def test_is_cacheable_returns_false_for_unlisted_queries(self) -> None:
        assert _QueryCache.is_cacheable("{ docker { containers { id } } }") is False
        assert _QueryCache.is_cacheable("{ info { os } }") is False

    def test_is_cacheable_mutation_check_is_prefix(self) -> None:
        """Queries that start with 'mutation' after whitespace are not cacheable."""
        assert _QueryCache.is_cacheable(" mutation { ... }") is False

    def test_is_cacheable_with_explicit_query_keyword(self) -> None:
        """Operation names after explicit 'query' keyword must be recognized."""
        assert _QueryCache.is_cacheable("query GetNetworkConfig { network { name } }") is True
        assert _QueryCache.is_cacheable("query GetOwner { owner { name } }") is True

    def test_is_cacheable_anonymous_query_returns_false(self) -> None:
        """Anonymous 'query { ... }' has no operation name — must not be cached."""
        assert _QueryCache.is_cacheable("query { network { name } }") is False

    def test_expired_entry_removed_from_store(self) -> None:
        """Accessing an expired entry should remove it from the internal store."""
        cache = _QueryCache()
        cache.put("GetOwner { }", None, {"owner": "root"})
        key = cache._cache_key("GetOwner { }", None)
        cache._store[key] = (time.monotonic() - 1.0, {"owner": "root"})
        assert key in cache._store
        cache.get("GetOwner { }", None)  # lookup of a stale entry evicts it
        assert key not in cache._store
# ---------------------------------------------------------------------------
# make_graphql_request — 429 retry behavior
# ---------------------------------------------------------------------------
class TestRateLimitRetry:
    """Tests for the 429 retry loop in make_graphql_request."""

    @pytest.fixture(autouse=True)
    def _patch_config(self):
        # Pin the API endpoint/key and neutralize real sleeping for every test.
        with (
            patch("unraid_mcp.core.client.UNRAID_API_URL", "https://unraid.local/graphql"),
            patch("unraid_mcp.core.client.UNRAID_API_KEY", "test-key"),
            patch("unraid_mcp.core.client.asyncio.sleep", new_callable=AsyncMock),
        ):
            yield

    def _make_429_response(self) -> MagicMock:
        # A throttled response; raise_for_status is stubbed so only the
        # status-code branch of the client is exercised.
        response = MagicMock()
        response.status_code = 429
        response.raise_for_status = MagicMock()
        return response

    def _make_ok_response(self, data: dict) -> MagicMock:
        # A successful GraphQL response wrapping `data` in the standard envelope.
        response = MagicMock()
        response.status_code = 200
        response.raise_for_status = MagicMock()
        response.json.return_value = {"data": data}
        return response

    async def test_single_429_then_success_retries(self) -> None:
        """One 429 followed by a success should return the data."""
        http = AsyncMock()
        http.post.side_effect = [
            self._make_429_response(),
            self._make_ok_response({"info": {"os": "Unraid"}}),
        ]
        with patch("unraid_mcp.core.client.get_http_client", return_value=http):
            result = await make_graphql_request("{ info { os } }")
        assert result == {"info": {"os": "Unraid"}}
        assert http.post.call_count == 2

    async def test_two_429s_then_success(self) -> None:
        """Two 429s followed by success returns data after 2 retries."""
        http = AsyncMock()
        http.post.side_effect = [
            self._make_429_response(),
            self._make_429_response(),
            self._make_ok_response({"x": 1}),
        ]
        with patch("unraid_mcp.core.client.get_http_client", return_value=http):
            result = await make_graphql_request("{ x }")
        assert result == {"x": 1}
        assert http.post.call_count == 3

    async def test_three_429s_raises_tool_error(self) -> None:
        """Three consecutive 429s (all retries exhausted) raises ToolError."""
        http = AsyncMock()
        http.post.side_effect = [self._make_429_response() for _ in range(3)]
        with (
            patch("unraid_mcp.core.client.get_http_client", return_value=http),
            pytest.raises(ToolError, match="rate limiting"),
        ):
            await make_graphql_request("{ info }")

    async def test_rate_limit_error_message_advises_wait(self) -> None:
        """The ToolError message should tell the user to wait ~10 seconds."""
        http = AsyncMock()
        http.post.side_effect = [self._make_429_response() for _ in range(3)]
        with (
            patch("unraid_mcp.core.client.get_http_client", return_value=http),
            pytest.raises(ToolError, match="10 seconds"),
        ):
            await make_graphql_request("{ info }")

View File

@@ -80,14 +80,6 @@ class TestDockerValidation:
with pytest.raises(ToolError, match="network_id"): with pytest.raises(ToolError, match="network_id"):
await tool_fn(action="network_details") await tool_fn(action="network_details")
async def test_non_logs_action_ignores_tail_lines_validation(
self, _mock_graphql: AsyncMock
) -> None:
_mock_graphql.return_value = {"docker": {"containers": []}}
tool_fn = _make_tool()
result = await tool_fn(action="list", tail_lines=0)
assert result["containers"] == []
class TestDockerActions: class TestDockerActions:
async def test_list(self, _mock_graphql: AsyncMock) -> None: async def test_list(self, _mock_graphql: AsyncMock) -> None:
@@ -183,7 +175,7 @@ class TestDockerActions:
"docker": {"updateAllContainers": [{"id": "c1", "state": "running"}]} "docker": {"updateAllContainers": [{"id": "c1", "state": "running"}]}
} }
tool_fn = _make_tool() tool_fn = _make_tool()
result = await tool_fn(action="update_all", confirm=True) result = await tool_fn(action="update_all")
assert result["success"] is True assert result["success"] is True
assert len(result["containers"]) == 1 assert len(result["containers"]) == 1
@@ -232,22 +224,9 @@ class TestDockerActions:
async def test_generic_exception_wraps_in_tool_error(self, _mock_graphql: AsyncMock) -> None: async def test_generic_exception_wraps_in_tool_error(self, _mock_graphql: AsyncMock) -> None:
_mock_graphql.side_effect = RuntimeError("unexpected failure") _mock_graphql.side_effect = RuntimeError("unexpected failure")
tool_fn = _make_tool() tool_fn = _make_tool()
with pytest.raises(ToolError, match="Failed to execute docker/list"): with pytest.raises(ToolError, match="unexpected failure"):
await tool_fn(action="list") await tool_fn(action="list")
async def test_short_id_prefix_ambiguous_rejected(self, _mock_graphql: AsyncMock) -> None:
_mock_graphql.return_value = {
"docker": {
"containers": [
{"id": "abcdef1234560000000000000000000000000000000000000000000000000000:local", "names": ["plex"]},
{"id": "abcdef1234561111111111111111111111111111111111111111111111111111:local", "names": ["sonarr"]},
]
}
}
tool_fn = _make_tool()
with pytest.raises(ToolError, match="ambiguous"):
await tool_fn(action="logs", container_id="abcdef123456")
class TestDockerMutationFailures: class TestDockerMutationFailures:
"""Tests for mutation responses that indicate failure or unexpected shapes.""" """Tests for mutation responses that indicate failure or unexpected shapes."""
@@ -292,16 +271,10 @@ class TestDockerMutationFailures:
"""update_all with no containers to update.""" """update_all with no containers to update."""
_mock_graphql.return_value = {"docker": {"updateAllContainers": []}} _mock_graphql.return_value = {"docker": {"updateAllContainers": []}}
tool_fn = _make_tool() tool_fn = _make_tool()
result = await tool_fn(action="update_all", confirm=True) result = await tool_fn(action="update_all")
assert result["success"] is True assert result["success"] is True
assert result["containers"] == [] assert result["containers"] == []
async def test_update_all_requires_confirm(self, _mock_graphql: AsyncMock) -> None:
"""update_all is destructive and requires confirm=True."""
tool_fn = _make_tool()
with pytest.raises(ToolError, match="destructive"):
await tool_fn(action="update_all")
async def test_mutation_timeout(self, _mock_graphql: AsyncMock) -> None: async def test_mutation_timeout(self, _mock_graphql: AsyncMock) -> None:
"""Mid-operation timeout during a docker mutation.""" """Mid-operation timeout during a docker mutation."""

View File

@@ -7,7 +7,6 @@ import pytest
from conftest import make_tool_fn from conftest import make_tool_fn
from unraid_mcp.core.exceptions import ToolError from unraid_mcp.core.exceptions import ToolError
from unraid_mcp.core.utils import safe_display_url
@pytest.fixture @pytest.fixture
@@ -100,7 +99,7 @@ class TestHealthActions:
"unraid_mcp.tools.health._diagnose_subscriptions", "unraid_mcp.tools.health._diagnose_subscriptions",
side_effect=RuntimeError("broken"), side_effect=RuntimeError("broken"),
), ),
pytest.raises(ToolError, match="Failed to execute health/diagnose"), pytest.raises(ToolError, match="broken"),
): ):
await tool_fn(action="diagnose") await tool_fn(action="diagnose")
@@ -115,7 +114,7 @@ class TestHealthActions:
assert "cpu_sub" in result assert "cpu_sub" in result
async def test_diagnose_import_error_internal(self) -> None: async def test_diagnose_import_error_internal(self) -> None:
"""_diagnose_subscriptions raises ToolError when subscription modules are unavailable.""" """_diagnose_subscriptions catches ImportError and returns error dict."""
import sys import sys
from unraid_mcp.tools.health import _diagnose_subscriptions from unraid_mcp.tools.health import _diagnose_subscriptions
@@ -127,70 +126,16 @@ class TestHealthActions:
try: try:
# Replace the modules with objects that raise ImportError on access # Replace the modules with objects that raise ImportError on access
with ( with patch.dict(
patch.dict(
sys.modules, sys.modules,
{ {
"unraid_mcp.subscriptions": None, "unraid_mcp.subscriptions": None,
"unraid_mcp.subscriptions.manager": None, "unraid_mcp.subscriptions.manager": None,
"unraid_mcp.subscriptions.resources": None, "unraid_mcp.subscriptions.resources": None,
}, },
),
pytest.raises(ToolError, match="Subscription modules not available"),
): ):
await _diagnose_subscriptions() result = await _diagnose_subscriptions()
assert "error" in result
finally: finally:
# Restore cached modules # Restore cached modules
sys.modules.update(cached) sys.modules.update(cached)
# ---------------------------------------------------------------------------
# _safe_display_url — URL redaction helper
# ---------------------------------------------------------------------------
class TestSafeDisplayUrl:
"""Verify that safe_display_url strips credentials/path and preserves scheme+host+port."""
def test_none_returns_none(self) -> None:
assert safe_display_url(None) is None
def test_empty_string_returns_none(self) -> None:
assert safe_display_url("") is None
def test_simple_url_scheme_and_host(self) -> None:
assert safe_display_url("https://unraid.local/graphql") == "https://unraid.local"
def test_preserves_port(self) -> None:
assert safe_display_url("https://10.1.0.2:31337/api/graphql") == "https://10.1.0.2:31337"
def test_strips_path(self) -> None:
result = safe_display_url("http://unraid.local/some/deep/path?query=1")
assert "path" not in result
assert "query" not in result
def test_strips_credentials(self) -> None:
result = safe_display_url("https://user:password@unraid.local/graphql")
assert "user" not in result
assert "password" not in result
assert result == "https://unraid.local"
def test_strips_query_params(self) -> None:
result = safe_display_url("http://host.local?token=abc&key=xyz")
assert "token" not in result
assert "abc" not in result
def test_http_scheme_preserved(self) -> None:
result = safe_display_url("http://10.0.0.1:8080/api")
assert result == "http://10.0.0.1:8080"
def test_tailscale_url(self) -> None:
result = safe_display_url("https://100.118.209.1:31337/graphql")
assert result == "https://100.118.209.1:31337"
def test_malformed_ipv6_url_returns_unparseable(self) -> None:
"""Malformed IPv6 brackets in netloc cause urlparse.hostname to raise ValueError."""
# urlparse("https://[invalid") parses without error, but accessing .hostname
# raises ValueError: Invalid IPv6 URL — this triggers the except branch.
result = safe_display_url("https://[invalid")
assert result == "<unparseable>"

View File

@@ -186,7 +186,7 @@ class TestUnraidInfoTool:
async def test_generic_exception_wraps(self, _mock_graphql: AsyncMock) -> None: async def test_generic_exception_wraps(self, _mock_graphql: AsyncMock) -> None:
_mock_graphql.side_effect = RuntimeError("unexpected") _mock_graphql.side_effect = RuntimeError("unexpected")
tool_fn = _make_tool() tool_fn = _make_tool()
with pytest.raises(ToolError, match="Failed to execute info/online"): with pytest.raises(ToolError, match="unexpected"):
await tool_fn(action="online") await tool_fn(action="online")
async def test_metrics(self, _mock_graphql: AsyncMock) -> None: async def test_metrics(self, _mock_graphql: AsyncMock) -> None:
@@ -201,7 +201,6 @@ class TestUnraidInfoTool:
_mock_graphql.return_value = {"services": [{"name": "docker", "state": "running"}]} _mock_graphql.return_value = {"services": [{"name": "docker", "state": "running"}]}
tool_fn = _make_tool() tool_fn = _make_tool()
result = await tool_fn(action="services") result = await tool_fn(action="services")
assert "services" in result
assert len(result["services"]) == 1 assert len(result["services"]) == 1
assert result["services"][0]["name"] == "docker" assert result["services"][0]["name"] == "docker"
@@ -226,7 +225,6 @@ class TestUnraidInfoTool:
} }
tool_fn = _make_tool() tool_fn = _make_tool()
result = await tool_fn(action="servers") result = await tool_fn(action="servers")
assert "servers" in result
assert len(result["servers"]) == 1 assert len(result["servers"]) == 1
assert result["servers"][0]["name"] == "tower" assert result["servers"][0]["name"] == "tower"
@@ -250,7 +248,6 @@ class TestUnraidInfoTool:
} }
tool_fn = _make_tool() tool_fn = _make_tool()
result = await tool_fn(action="ups_devices") result = await tool_fn(action="ups_devices")
assert "ups_devices" in result
assert len(result["ups_devices"]) == 1 assert len(result["ups_devices"]) == 1
assert result["ups_devices"][0]["model"] == "APC" assert result["ups_devices"][0]["model"] == "APC"

View File

@@ -100,5 +100,5 @@ class TestKeysActions:
async def test_generic_exception_wraps(self, _mock_graphql: AsyncMock) -> None: async def test_generic_exception_wraps(self, _mock_graphql: AsyncMock) -> None:
_mock_graphql.side_effect = RuntimeError("connection lost") _mock_graphql.side_effect = RuntimeError("connection lost")
tool_fn = _make_tool() tool_fn = _make_tool()
with pytest.raises(ToolError, match="Failed to execute keys/list"): with pytest.raises(ToolError, match="connection lost"):
await tool_fn(action="list") await tool_fn(action="list")

View File

@@ -92,7 +92,7 @@ class TestNotificationsActions:
title="Test", title="Test",
subject="Test Subject", subject="Test Subject",
description="Test Desc", description="Test Desc",
importance="normal", importance="info",
) )
assert result["success"] is True assert result["success"] is True
@@ -149,89 +149,5 @@ class TestNotificationsActions:
async def test_generic_exception_wraps(self, _mock_graphql: AsyncMock) -> None: async def test_generic_exception_wraps(self, _mock_graphql: AsyncMock) -> None:
_mock_graphql.side_effect = RuntimeError("boom") _mock_graphql.side_effect = RuntimeError("boom")
tool_fn = _make_tool() tool_fn = _make_tool()
with pytest.raises(ToolError, match="Failed to execute notifications/overview"): with pytest.raises(ToolError, match="boom"):
await tool_fn(action="overview") await tool_fn(action="overview")
class TestNotificationsCreateValidation:
"""Tests for importance enum and field length validation added in this PR."""
async def test_invalid_importance_rejected(self, _mock_graphql: AsyncMock) -> None:
tool_fn = _make_tool()
with pytest.raises(ToolError, match="importance must be one of"):
await tool_fn(
action="create",
title="T",
subject="S",
description="D",
importance="invalid",
)
async def test_info_importance_rejected(self, _mock_graphql: AsyncMock) -> None:
"""INFO is listed in old docstring examples but rejected by the validator."""
tool_fn = _make_tool()
with pytest.raises(ToolError, match="importance must be one of"):
await tool_fn(
action="create",
title="T",
subject="S",
description="D",
importance="info",
)
async def test_alert_importance_accepted(self, _mock_graphql: AsyncMock) -> None:
_mock_graphql.return_value = {
"notifications": {"createNotification": {"id": "n:1", "importance": "ALERT"}}
}
tool_fn = _make_tool()
result = await tool_fn(
action="create", title="T", subject="S", description="D", importance="alert"
)
assert result["success"] is True
async def test_title_too_long_rejected(self, _mock_graphql: AsyncMock) -> None:
tool_fn = _make_tool()
with pytest.raises(ToolError, match="title must be at most 200"):
await tool_fn(
action="create",
title="x" * 201,
subject="S",
description="D",
importance="normal",
)
async def test_subject_too_long_rejected(self, _mock_graphql: AsyncMock) -> None:
tool_fn = _make_tool()
with pytest.raises(ToolError, match="subject must be at most 500"):
await tool_fn(
action="create",
title="T",
subject="x" * 501,
description="D",
importance="normal",
)
async def test_description_too_long_rejected(self, _mock_graphql: AsyncMock) -> None:
tool_fn = _make_tool()
with pytest.raises(ToolError, match="description must be at most 2000"):
await tool_fn(
action="create",
title="T",
subject="S",
description="x" * 2001,
importance="normal",
)
async def test_title_at_max_accepted(self, _mock_graphql: AsyncMock) -> None:
_mock_graphql.return_value = {
"notifications": {"createNotification": {"id": "n:1", "importance": "NORMAL"}}
}
tool_fn = _make_tool()
result = await tool_fn(
action="create",
title="x" * 200,
subject="S",
description="D",
importance="normal",
)
assert result["success"] is True

View File

@@ -19,6 +19,7 @@ def _make_tool():
return make_tool_fn("unraid_mcp.tools.rclone", "register_rclone_tool", "unraid_rclone") return make_tool_fn("unraid_mcp.tools.rclone", "register_rclone_tool", "unraid_rclone")
@pytest.mark.usefixtures("_mock_graphql")
class TestRcloneValidation: class TestRcloneValidation:
async def test_delete_requires_confirm(self) -> None: async def test_delete_requires_confirm(self) -> None:
tool_fn = _make_tool() tool_fn = _make_tool()
@@ -99,83 +100,3 @@ class TestRcloneActions:
tool_fn = _make_tool() tool_fn = _make_tool()
with pytest.raises(ToolError, match="Failed to delete"): with pytest.raises(ToolError, match="Failed to delete"):
await tool_fn(action="delete_remote", name="gdrive", confirm=True) await tool_fn(action="delete_remote", name="gdrive", confirm=True)
class TestRcloneConfigDataValidation:
"""Tests for _validate_config_data security guards."""
async def test_path_traversal_in_key_rejected(self, _mock_graphql: AsyncMock) -> None:
tool_fn = _make_tool()
with pytest.raises(ToolError, match="disallowed characters"):
await tool_fn(
action="create_remote",
name="r",
provider_type="s3",
config_data={"../evil": "value"},
)
async def test_shell_metachar_in_key_rejected(self, _mock_graphql: AsyncMock) -> None:
tool_fn = _make_tool()
with pytest.raises(ToolError, match="disallowed characters"):
await tool_fn(
action="create_remote",
name="r",
provider_type="s3",
config_data={"key;rm": "value"},
)
async def test_too_many_keys_rejected(self, _mock_graphql: AsyncMock) -> None:
tool_fn = _make_tool()
with pytest.raises(ToolError, match="max 50"):
await tool_fn(
action="create_remote",
name="r",
provider_type="s3",
config_data={f"key{i}": "v" for i in range(51)},
)
async def test_dict_value_rejected(self, _mock_graphql: AsyncMock) -> None:
tool_fn = _make_tool()
with pytest.raises(ToolError, match="string, number, or boolean"):
await tool_fn(
action="create_remote",
name="r",
provider_type="s3",
config_data={"nested": {"key": "val"}},
)
async def test_value_too_long_rejected(self, _mock_graphql: AsyncMock) -> None:
tool_fn = _make_tool()
with pytest.raises(ToolError, match="exceeds max length"):
await tool_fn(
action="create_remote",
name="r",
provider_type="s3",
config_data={"key": "x" * 4097},
)
async def test_boolean_value_accepted(self, _mock_graphql: AsyncMock) -> None:
_mock_graphql.return_value = {
"rclone": {"createRCloneRemote": {"name": "r", "type": "s3"}}
}
tool_fn = _make_tool()
result = await tool_fn(
action="create_remote",
name="r",
provider_type="s3",
config_data={"use_path_style": True},
)
assert result["success"] is True
async def test_int_value_accepted(self, _mock_graphql: AsyncMock) -> None:
_mock_graphql.return_value = {
"rclone": {"createRCloneRemote": {"name": "r", "type": "sftp"}}
}
tool_fn = _make_tool()
result = await tool_fn(
action="create_remote",
name="r",
provider_type="sftp",
config_data={"port": 22},
)
assert result["success"] is True

View File

@@ -7,7 +7,7 @@ import pytest
from conftest import make_tool_fn from conftest import make_tool_fn
from unraid_mcp.core.exceptions import ToolError from unraid_mcp.core.exceptions import ToolError
from unraid_mcp.core.utils import format_bytes, format_kb, safe_get from unraid_mcp.tools.storage import format_bytes
# --- Unit tests for helpers --- # --- Unit tests for helpers ---
@@ -77,87 +77,6 @@ class TestStorageValidation:
result = await tool_fn(action="logs", log_path="/var/log/syslog") result = await tool_fn(action="logs", log_path="/var/log/syslog")
assert result["content"] == "ok" assert result["content"] == "ok"
async def test_logs_tail_lines_too_large(self, _mock_graphql: AsyncMock) -> None:
tool_fn = _make_tool()
with pytest.raises(ToolError, match="tail_lines must be between"):
await tool_fn(action="logs", log_path="/var/log/syslog", tail_lines=10_001)
async def test_logs_tail_lines_zero_rejected(self, _mock_graphql: AsyncMock) -> None:
tool_fn = _make_tool()
with pytest.raises(ToolError, match="tail_lines must be between"):
await tool_fn(action="logs", log_path="/var/log/syslog", tail_lines=0)
async def test_logs_tail_lines_at_max_accepted(self, _mock_graphql: AsyncMock) -> None:
_mock_graphql.return_value = {"logFile": {"path": "/var/log/syslog", "content": "ok"}}
tool_fn = _make_tool()
result = await tool_fn(action="logs", log_path="/var/log/syslog", tail_lines=10_000)
assert result["content"] == "ok"
async def test_non_logs_action_ignores_tail_lines_validation(
self, _mock_graphql: AsyncMock
) -> None:
_mock_graphql.return_value = {"shares": []}
tool_fn = _make_tool()
result = await tool_fn(action="shares", tail_lines=0)
assert result["shares"] == []
class TestFormatKb:
def test_none_returns_na(self) -> None:
assert format_kb(None) == "N/A"
def test_invalid_string_returns_na(self) -> None:
assert format_kb("not-a-number") == "N/A"
def test_kilobytes_range(self) -> None:
assert format_kb(512) == "512.00 KB"
def test_megabytes_range(self) -> None:
assert format_kb(2048) == "2.00 MB"
def test_gigabytes_range(self) -> None:
assert format_kb(1_048_576) == "1.00 GB"
def test_terabytes_range(self) -> None:
assert format_kb(1_073_741_824) == "1.00 TB"
def test_boundary_exactly_1024_kb(self) -> None:
# 1024 KB = 1 MB
assert format_kb(1024) == "1.00 MB"
class TestSafeGet:
def test_simple_key_access(self) -> None:
assert safe_get({"a": 1}, "a") == 1
def test_nested_key_access(self) -> None:
assert safe_get({"a": {"b": "val"}}, "a", "b") == "val"
def test_missing_key_returns_none(self) -> None:
assert safe_get({"a": 1}, "missing") is None
def test_none_intermediate_returns_default(self) -> None:
assert safe_get({"a": None}, "a", "b") is None
def test_custom_default_returned(self) -> None:
assert safe_get({}, "x", default="fallback") == "fallback"
def test_non_dict_intermediate_returns_default(self) -> None:
assert safe_get({"a": "string"}, "a", "b") is None
def test_empty_list_default(self) -> None:
result = safe_get({}, "missing", default=[])
assert result == []
def test_zero_value_not_replaced_by_default(self) -> None:
assert safe_get({"temp": 0}, "temp", default="N/A") == 0
def test_false_value_not_replaced_by_default(self) -> None:
assert safe_get({"active": False}, "active", default=True) is False
def test_empty_string_not_replaced_by_default(self) -> None:
assert safe_get({"name": ""}, "name", default="unknown") == ""
class TestStorageActions: class TestStorageActions:
async def test_shares(self, _mock_graphql: AsyncMock) -> None: async def test_shares(self, _mock_graphql: AsyncMock) -> None:

View File

@@ -1,156 +0,0 @@
"""Tests for _cap_log_content in subscriptions/manager.py.
_cap_log_content is a pure utility that prevents unbounded memory growth from
log subscription data. It must: return a NEW dict (not mutate), recursively
cap nested 'content' fields, and only truncate when both byte limit and line
limit are exceeded.
"""
from unittest.mock import patch
from unraid_mcp.subscriptions.manager import _cap_log_content
class TestCapLogContentImmutability:
    """The function must return a new dict — never mutate the input."""

    def test_returns_new_dict(self) -> None:
        payload = {"key": "value"}
        # Identity check: the result is a fresh dict, not the argument.
        assert _cap_log_content(payload) is not payload

    def test_input_not_mutated_on_passthrough(self) -> None:
        payload = {"content": "short text", "other": "value"}
        before = payload["content"]
        _cap_log_content(payload)
        assert payload["content"] == before

    def test_input_not_mutated_on_truncation(self) -> None:
        # Shrink the limits so the truncation branch actually runs.
        big_content = "\n".join(f"line {i}" for i in range(200))
        payload = {"content": big_content}
        with (
            patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10),
            patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50),
        ):
            _cap_log_content(payload)
        # Even when the output was truncated, the caller's dict is untouched.
        assert payload["content"] == big_content
class TestCapLogContentSmallData:
    """Content below the byte limit must be returned unchanged."""

    def test_small_content_unchanged(self) -> None:
        payload = {"content": "just a few lines\nof log data\n"}
        assert _cap_log_content(payload)["content"] == payload["content"]

    def test_non_content_keys_passed_through(self) -> None:
        # Dicts without a 'content' key are copied verbatim.
        payload = {"name": "cpu_subscription", "timestamp": "2026-02-18T00:00:00Z"}
        assert _cap_log_content(payload) == payload

    def test_integer_value_passed_through(self) -> None:
        payload = {"count": 42, "active": True}
        assert _cap_log_content(payload) == payload
class TestCapLogContentTruncation:
    """Content exceeding both byte AND line limits must be truncated to the last N lines."""

    def test_oversized_content_truncated_and_byte_capped(self) -> None:
        # 200 lines against a tiny byte budget: only a recent suffix survives.
        payload = {"content": "\n".join(f"line {i}" for i in range(200))}
        with (
            patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10),
            patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50),
        ):
            capped = _cap_log_content(payload)
        kept_lines = capped["content"].splitlines()
        assert len(capped["content"].encode("utf-8", errors="replace")) <= 10
        # Truncation keeps the newest line at the end.
        assert kept_lines[-1] == "line 199"

    def test_content_with_fewer_lines_than_limit_still_honors_byte_cap(self) -> None:
        """If byte limit is exceeded, output must still be capped even with few lines."""
        # 30 lines, byte limit 10, line limit 50 -> byte cap wins regardless of line count.
        payload = {"content": "\n".join(f"line {i}" for i in range(30))}
        with (
            patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10),
            patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50),
        ):
            capped = _cap_log_content(payload)
        assert len(capped["content"].encode("utf-8", errors="replace")) <= 10

    def test_non_content_keys_preserved_alongside_truncated_content(self) -> None:
        payload = {
            "content": "\n".join(f"line {i}" for i in range(200)),
            "path": "/var/log/syslog",
            "total_lines": 200,
        }
        with (
            patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10),
            patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50),
        ):
            capped = _cap_log_content(payload)
        # Sibling keys survive truncation untouched.
        assert capped["path"] == "/var/log/syslog"
        assert capped["total_lines"] == 200
        assert len(capped["content"].encode("utf-8", errors="replace")) <= 10
class TestCapLogContentNested:
    """Nested 'content' fields inside sub-dicts must also be capped recursively."""

    def test_nested_content_field_capped(self) -> None:
        payload = {
            "logFile": {
                "content": "\n".join(f"line {i}" for i in range(200)),
                "path": "/var/log/syslog",
            }
        }
        with (
            patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10),
            patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50),
        ):
            capped = _cap_log_content(payload)
        # One level down: content capped, sibling key preserved.
        assert len(capped["logFile"]["content"].encode("utf-8", errors="replace")) <= 10
        assert capped["logFile"]["path"] == "/var/log/syslog"

    def test_deeply_nested_content_capped(self) -> None:
        payload = {"outer": {"inner": {"content": "\n".join(f"line {i}" for i in range(200))}}}
        with (
            patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10),
            patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50),
        ):
            capped = _cap_log_content(payload)
        # Two levels down the cap must still apply.
        assert len(capped["outer"]["inner"]["content"].encode("utf-8", errors="replace")) <= 10

    def test_nested_non_content_keys_unaffected(self) -> None:
        payload = {"metrics": {"cpu": 42.5, "memory": 8192}}
        assert _cap_log_content(payload) == payload
class TestCapLogContentSingleMassiveLine:
    """A single line larger than the byte cap must be hard-capped at byte level."""

    def test_single_massive_line_hard_caps_bytes(self) -> None:
        # A lone 200-char line with no newlines cannot be reduced by dropping
        # lines (len(lines) == 1), so the last-resort byte-slice path in the
        # manager (manager.py:65-69) must fire.
        payload = {"content": "x" * 200}
        with (
            patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10),
            patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 5_000),
        ):
            capped = _cap_log_content(payload)
        assert len(capped["content"].encode("utf-8", errors="replace")) <= 10

    def test_single_massive_line_input_not_mutated(self) -> None:
        original = "x" * 200
        payload = {"content": original}
        with (
            patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10),
            patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 5_000),
        ):
            _cap_log_content(payload)
        # The byte-slice fallback must also leave the input dict alone.
        assert payload["content"] == original

View File

@@ -1,131 +0,0 @@
"""Tests for _validate_subscription_query in diagnostics.py.
Security-critical: this function is the only guard against arbitrary GraphQL
operations (mutations, queries) being sent over the WebSocket subscription channel.
"""
import pytest
from unraid_mcp.core.exceptions import ToolError
from unraid_mcp.subscriptions.diagnostics import (
_ALLOWED_SUBSCRIPTION_NAMES,
_validate_subscription_query,
)
class TestValidateSubscriptionQueryAllowed:
    """All whitelisted subscription names must be accepted."""

    @pytest.mark.parametrize("sub_name", sorted(_ALLOWED_SUBSCRIPTION_NAMES))
    def test_all_allowed_names_accepted(self, sub_name: str) -> None:
        # Each name in the whitelist must round-trip through validation.
        wrapped = f"subscription {{ {sub_name} {{ data }} }}"
        assert _validate_subscription_query(wrapped) == sub_name

    def test_returns_extracted_subscription_name(self) -> None:
        extracted = _validate_subscription_query("subscription { cpuSubscription { usage } }")
        assert extracted == "cpuSubscription"

    def test_leading_whitespace_accepted(self) -> None:
        extracted = _validate_subscription_query(" subscription { memorySubscription { free } }")
        assert extracted == "memorySubscription"

    def test_multiline_query_accepted(self) -> None:
        extracted = _validate_subscription_query(
            "subscription {\n logFileSubscription {\n content\n }\n}"
        )
        assert extracted == "logFileSubscription"

    def test_case_insensitive_subscription_keyword(self) -> None:
        """'SUBSCRIPTION' should be accepted (regex uses IGNORECASE)."""
        extracted = _validate_subscription_query("SUBSCRIPTION { cpuSubscription { usage } }")
        assert extracted == "cpuSubscription"
class TestValidateSubscriptionQueryForbiddenKeywords:
    """Standalone 'mutation'/'query' keywords must cause rejection."""

    def _assert_rejected(self, gql: str) -> None:
        """Expect the validator to refuse this operation as non-subscription."""
        with pytest.raises(ToolError, match="must be a subscription"):
            _validate_subscription_query(gql)

    def test_mutation_keyword_rejected(self) -> None:
        self._assert_rejected('mutation { docker { start(id: "abc") } }')

    def test_query_keyword_rejected(self) -> None:
        self._assert_rejected("query { info { os { platform } } }")

    def test_mutation_embedded_in_subscription_rejected(self) -> None:
        """A bare 'mutation' token anywhere in the query triggers rejection."""
        self._assert_rejected("subscription { cpuSubscription { mutation data } }")

    def test_query_embedded_in_subscription_rejected(self) -> None:
        self._assert_rejected("subscription { cpuSubscription { query data } }")

    def test_mutation_case_insensitive_rejection(self) -> None:
        self._assert_rejected('MUTATION { docker { start(id: "abc") } }')

    def test_mutation_field_identifier_not_rejected(self) -> None:
        """'mutationField' is an identifier, not the keyword — must pass.

        Exercises the \\b word boundary in _FORBIDDEN_KEYWORDS.
        """
        gql = "subscription { cpuSubscription { mutationField } }"
        assert _validate_subscription_query(gql) == "cpuSubscription"

    def test_query_field_identifier_not_rejected(self) -> None:
        """'queryResult' is an identifier, not the keyword — must pass."""
        gql = "subscription { cpuSubscription { queryResult } }"
        assert _validate_subscription_query(gql) == "cpuSubscription"
class TestValidateSubscriptionQueryInvalidFormat:
    """Inputs not shaped like 'subscription { field ... }' must be rejected."""

    def test_empty_string_rejected(self) -> None:
        with pytest.raises(ToolError, match="must start with 'subscription'"):
            _validate_subscription_query("")

    def test_plain_identifier_rejected(self) -> None:
        with pytest.raises(ToolError, match="must start with 'subscription'"):
            _validate_subscription_query("cpuSubscription { usage }")

    def test_missing_operation_body_rejected(self) -> None:
        with pytest.raises(ToolError, match="must start with 'subscription'"):
            _validate_subscription_query("subscription")

    def test_subscription_without_field_rejected(self) -> None:
        """An empty selection set has no field name, so the pattern cannot match."""
        with pytest.raises(ToolError, match="must start with 'subscription'"):
            _validate_subscription_query("subscription { }")
class TestValidateSubscriptionQueryUnknownName:
    """Well-formed subscriptions with non-whitelisted names must still fail."""

    def test_unknown_subscription_name_rejected(self) -> None:
        gql = "subscription { unknownSubscription { data } }"
        with pytest.raises(ToolError, match="not allowed"):
            _validate_subscription_query(gql)

    def test_error_message_includes_allowed_list(self) -> None:
        """For usability, the error should enumerate the permitted names."""
        gql = "subscription { badSub { data } }"
        with pytest.raises(ToolError, match="Allowed subscriptions"):
            _validate_subscription_query(gql)

    def test_arbitrary_field_name_rejected(self) -> None:
        gql = "subscription { users { id email } }"
        with pytest.raises(ToolError, match="not allowed"):
            _validate_subscription_query(gql)

    def test_close_but_not_whitelisted_rejected(self) -> None:
        """A near-miss like 'cpu' (missing the 'Subscription' suffix) fails."""
        gql = "subscription { cpu { usage } }"
        with pytest.raises(ToolError, match="not allowed"):
            _validate_subscription_query(gql)

View File

@@ -1,6 +1,7 @@
"""Unraid MCP Server Package.""" """Unraid MCP Server Package.
from .version import VERSION A modular MCP (Model Context Protocol) server that provides tools to interact
with an Unraid server's GraphQL API.
"""
__version__ = "0.2.0"
__version__ = VERSION

View File

@@ -5,10 +5,16 @@ that cap at 10MB and start over (no rotation) for consistent use across all modu
""" """
import logging import logging
from datetime import datetime
from pathlib import Path from pathlib import Path
import pytz
from rich.align import Align
from rich.console import Console from rich.console import Console
from rich.logging import RichHandler from rich.logging import RichHandler
from rich.panel import Panel
from rich.rule import Rule
from rich.text import Text
try: try:
@@ -22,7 +28,7 @@ from .settings import LOG_FILE_PATH, LOG_LEVEL_STR
# Global Rich console for consistent formatting # Global Rich console for consistent formatting
console = Console(stderr=True) console = Console(stderr=True, force_terminal=True)
class OverwriteFileHandler(logging.FileHandler): class OverwriteFileHandler(logging.FileHandler):
@@ -39,18 +45,12 @@ class OverwriteFileHandler(logging.FileHandler):
delay: Whether to delay file opening delay: Whether to delay file opening
""" """
self.max_bytes = max_bytes self.max_bytes = max_bytes
self._emit_count = 0
self._check_interval = 100
super().__init__(filename, mode, encoding, delay) super().__init__(filename, mode, encoding, delay)
def emit(self, record): def emit(self, record):
"""Emit a record, checking file size periodically and overwriting if needed.""" """Emit a record, checking file size and overwriting if needed."""
self._emit_count += 1 # Check file size before writing
if ( if self.stream and hasattr(self.stream, "name"):
(self._emit_count == 1 or self._emit_count % self._check_interval == 0)
and self.stream
and hasattr(self.stream, "name")
):
try: try:
base_path = Path(self.baseFilename) base_path = Path(self.baseFilename)
if base_path.exists(): if base_path.exists():
@@ -91,28 +91,6 @@ class OverwriteFileHandler(logging.FileHandler):
super().emit(record) super().emit(record)
def _create_shared_file_handler() -> OverwriteFileHandler:
"""Create the single shared file handler for all loggers.
Returns:
Configured OverwriteFileHandler instance
"""
numeric_log_level = getattr(logging, LOG_LEVEL_STR, logging.INFO)
handler = OverwriteFileHandler(LOG_FILE_PATH, max_bytes=10 * 1024 * 1024, encoding="utf-8")
handler.setLevel(numeric_log_level)
handler.setFormatter(
logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(module)s - %(funcName)s - %(lineno)d - %(message)s"
)
)
return handler
# Single shared file handler — all loggers reuse this instance to avoid
# race conditions from multiple OverwriteFileHandler instances on the same file.
_shared_file_handler = _create_shared_file_handler()
def setup_logger(name: str = "UnraidMCPServer") -> logging.Logger: def setup_logger(name: str = "UnraidMCPServer") -> logging.Logger:
"""Set up and configure the logger with console and file handlers. """Set up and configure the logger with console and file handlers.
@@ -140,13 +118,19 @@ def setup_logger(name: str = "UnraidMCPServer") -> logging.Logger:
show_level=True, show_level=True,
show_path=False, show_path=False,
rich_tracebacks=True, rich_tracebacks=True,
tracebacks_show_locals=False, tracebacks_show_locals=True,
) )
console_handler.setLevel(numeric_log_level) console_handler.setLevel(numeric_log_level)
logger.addHandler(console_handler) logger.addHandler(console_handler)
# Reuse the shared file handler # File Handler with 10MB cap (overwrites instead of rotating)
logger.addHandler(_shared_file_handler) file_handler = OverwriteFileHandler(LOG_FILE_PATH, max_bytes=10 * 1024 * 1024, encoding="utf-8")
file_handler.setLevel(numeric_log_level)
file_formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(module)s - %(funcName)s - %(lineno)d - %(message)s"
)
file_handler.setFormatter(file_formatter)
logger.addHandler(file_handler)
return logger return logger
@@ -173,14 +157,20 @@ def configure_fastmcp_logger_with_rich() -> logging.Logger | None:
show_level=True, show_level=True,
show_path=False, show_path=False,
rich_tracebacks=True, rich_tracebacks=True,
tracebacks_show_locals=False, tracebacks_show_locals=True,
markup=True, markup=True,
) )
console_handler.setLevel(numeric_log_level) console_handler.setLevel(numeric_log_level)
fastmcp_logger.addHandler(console_handler) fastmcp_logger.addHandler(console_handler)
# Reuse the shared file handler # File Handler with 10MB cap (overwrites instead of rotating)
fastmcp_logger.addHandler(_shared_file_handler) file_handler = OverwriteFileHandler(LOG_FILE_PATH, max_bytes=10 * 1024 * 1024, encoding="utf-8")
file_handler.setLevel(numeric_log_level)
file_formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(module)s - %(funcName)s - %(lineno)d - %(message)s"
)
file_handler.setFormatter(file_formatter)
fastmcp_logger.addHandler(file_handler)
fastmcp_logger.setLevel(numeric_log_level) fastmcp_logger.setLevel(numeric_log_level)
@@ -196,19 +186,30 @@ def configure_fastmcp_logger_with_rich() -> logging.Logger | None:
show_level=True, show_level=True,
show_path=False, show_path=False,
rich_tracebacks=True, rich_tracebacks=True,
tracebacks_show_locals=False, tracebacks_show_locals=True,
markup=True, markup=True,
) )
root_console_handler.setLevel(numeric_log_level) root_console_handler.setLevel(numeric_log_level)
root_logger.addHandler(root_console_handler) root_logger.addHandler(root_console_handler)
# Reuse the shared file handler for root logger # File Handler for root logger with 10MB cap (overwrites instead of rotating)
root_logger.addHandler(_shared_file_handler) root_file_handler = OverwriteFileHandler(
LOG_FILE_PATH, max_bytes=10 * 1024 * 1024, encoding="utf-8"
)
root_file_handler.setLevel(numeric_log_level)
root_file_handler.setFormatter(file_formatter)
root_logger.addHandler(root_file_handler)
root_logger.setLevel(numeric_log_level) root_logger.setLevel(numeric_log_level)
return fastmcp_logger return fastmcp_logger
def setup_uvicorn_logging() -> logging.Logger | None:
"""Configure uvicorn and other third-party loggers to use Rich formatting."""
# This function is kept for backward compatibility but now delegates to FastMCP
return configure_fastmcp_logger_with_rich()
def log_configuration_status(logger: logging.Logger) -> None: def log_configuration_status(logger: logging.Logger) -> None:
"""Log configuration status at startup. """Log configuration status at startup.
@@ -241,6 +242,97 @@ def log_configuration_status(logger: logging.Logger) -> None:
logger.error(f"Missing required configuration: {config['missing_config']}") logger.error(f"Missing required configuration: {config['missing_config']}")
# Development logging helpers for Rich formatting
def get_est_timestamp() -> str:
"""Get current timestamp in EST timezone with YY/MM/DD format."""
est = pytz.timezone("US/Eastern")
now = datetime.now(est)
return now.strftime("%y/%m/%d %H:%M:%S")
def log_header(title: str) -> None:
"""Print a beautiful header panel with Nordic blue styling."""
panel = Panel(
Align.center(Text(title, style="bold white")),
style="#5E81AC", # Nordic blue
padding=(0, 2),
border_style="#81A1C1", # Light Nordic blue
)
console.print(panel)
def log_with_level_and_indent(message: str, level: str = "info", indent: int = 0) -> None:
"""Log a message with specific level and indentation."""
timestamp = get_est_timestamp()
indent_str = " " * indent
# Enhanced Nordic color scheme with more blues
level_config = {
"error": {"color": "#BF616A", "icon": "", "style": "bold"}, # Nordic red
"warning": {"color": "#EBCB8B", "icon": "⚠️", "style": ""}, # Nordic yellow
"success": {"color": "#A3BE8C", "icon": "", "style": "bold"}, # Nordic green
"info": {"color": "#5E81AC", "icon": "\u2139\ufe0f", "style": "bold"}, # Nordic blue (bold)
"status": {"color": "#81A1C1", "icon": "🔍", "style": ""}, # Light Nordic blue
"debug": {"color": "#4C566A", "icon": "🐛", "style": ""}, # Nordic dark gray
}
config = level_config.get(
level, {"color": "#81A1C1", "icon": "", "style": ""}
) # Default to light Nordic blue
# Create beautifully formatted text
text = Text()
# Timestamp with Nordic blue styling
text.append(f"[{timestamp}]", style="#81A1C1") # Light Nordic blue for timestamps
text.append(" ")
# Indentation with Nordic blue styling
if indent > 0:
text.append(indent_str, style="#81A1C1")
# Level icon (only for certain levels)
if level in ["error", "warning", "success"]:
# Extract emoji from message if it starts with one, to avoid duplication
if message and len(message) > 0 and ord(message[0]) >= 0x1F600: # Emoji range
# Message already has emoji, don't add icon
pass
else:
text.append(f"{config['icon']} ", style=config["color"])
# Message content
message_style = f"{config['color']} {config['style']}".strip()
text.append(message, style=message_style)
console.print(text)
def log_separator() -> None:
"""Print a beautiful separator line with Nordic blue styling."""
console.print(Rule(style="#81A1C1"))
# Convenience functions for different log levels
def log_error(message: str, indent: int = 0) -> None:
log_with_level_and_indent(message, "error", indent)
def log_warning(message: str, indent: int = 0) -> None:
log_with_level_and_indent(message, "warning", indent)
def log_success(message: str, indent: int = 0) -> None:
log_with_level_and_indent(message, "success", indent)
def log_info(message: str, indent: int = 0) -> None:
log_with_level_and_indent(message, "info", indent)
def log_status(message: str, indent: int = 0) -> None:
log_with_level_and_indent(message, "status", indent)
# Global logger instance - modules can import this directly # Global logger instance - modules can import this directly
if FASTMCP_AVAILABLE: if FASTMCP_AVAILABLE:
# Use FastMCP logger with Rich formatting # Use FastMCP logger with Rich formatting
@@ -249,3 +341,5 @@ if FASTMCP_AVAILABLE:
else: else:
# Fallback to our custom logger if FastMCP is not available # Fallback to our custom logger if FastMCP is not available
logger = setup_logger() logger = setup_logger()
# Setup uvicorn logging when module is imported
setup_uvicorn_logging()

View File

@@ -10,8 +10,6 @@ from typing import Any
from dotenv import load_dotenv from dotenv import load_dotenv
from ..version import VERSION as APP_VERSION
# Get the script directory (config module location) # Get the script directory (config module location)
SCRIPT_DIR = Path(__file__).parent # /home/user/code/unraid-mcp/unraid_mcp/config/ SCRIPT_DIR = Path(__file__).parent # /home/user/code/unraid-mcp/unraid_mcp/config/
@@ -32,13 +30,16 @@ for dotenv_path in dotenv_paths:
load_dotenv(dotenv_path=dotenv_path) load_dotenv(dotenv_path=dotenv_path)
break break
# Application Version
VERSION = "0.2.0"
# Core API Configuration # Core API Configuration
UNRAID_API_URL = os.getenv("UNRAID_API_URL") UNRAID_API_URL = os.getenv("UNRAID_API_URL")
UNRAID_API_KEY = os.getenv("UNRAID_API_KEY") UNRAID_API_KEY = os.getenv("UNRAID_API_KEY")
# Server Configuration # Server Configuration
UNRAID_MCP_PORT = int(os.getenv("UNRAID_MCP_PORT", "6970")) UNRAID_MCP_PORT = int(os.getenv("UNRAID_MCP_PORT", "6970"))
UNRAID_MCP_HOST = os.getenv("UNRAID_MCP_HOST", "0.0.0.0") # noqa: S104 — intentional for Docker UNRAID_MCP_HOST = os.getenv("UNRAID_MCP_HOST", "0.0.0.0")
UNRAID_MCP_TRANSPORT = os.getenv("UNRAID_MCP_TRANSPORT", "streamable-http").lower() UNRAID_MCP_TRANSPORT = os.getenv("UNRAID_MCP_TRANSPORT", "streamable-http").lower()
# SSL Configuration # SSL Configuration
@@ -53,18 +54,11 @@ else: # Path to CA bundle
# Logging Configuration # Logging Configuration
LOG_LEVEL_STR = os.getenv("UNRAID_MCP_LOG_LEVEL", "INFO").upper() LOG_LEVEL_STR = os.getenv("UNRAID_MCP_LOG_LEVEL", "INFO").upper()
LOG_FILE_NAME = os.getenv("UNRAID_MCP_LOG_FILE", "unraid-mcp.log") LOG_FILE_NAME = os.getenv("UNRAID_MCP_LOG_FILE", "unraid-mcp.log")
# Use /.dockerenv as the container indicator for robust Docker detection. LOGS_DIR = Path("/tmp")
IS_DOCKER = Path("/.dockerenv").exists()
LOGS_DIR = Path("/app/logs") if IS_DOCKER else PROJECT_ROOT / "logs"
LOG_FILE_PATH = LOGS_DIR / LOG_FILE_NAME LOG_FILE_PATH = LOGS_DIR / LOG_FILE_NAME
# Ensure logs directory exists; if creation fails, fall back to /tmp. # Ensure logs directory exists
try:
LOGS_DIR.mkdir(parents=True, exist_ok=True) LOGS_DIR.mkdir(parents=True, exist_ok=True)
except OSError:
LOGS_DIR = PROJECT_ROOT / ".cache" / "logs"
LOGS_DIR.mkdir(parents=True, exist_ok=True)
LOG_FILE_PATH = LOGS_DIR / LOG_FILE_NAME
# HTTP Client Configuration # HTTP Client Configuration
TIMEOUT_CONFIG = { TIMEOUT_CONFIG = {
@@ -110,5 +104,3 @@ def get_config_summary() -> dict[str, Any]:
"config_valid": is_valid, "config_valid": is_valid,
"missing_config": missing if not is_valid else None, "missing_config": missing if not is_valid else None,
} }
# Re-export application version from a single source of truth.
VERSION = APP_VERSION

View File

@@ -5,11 +5,8 @@ to the Unraid API with proper timeout handling and error management.
""" """
import asyncio import asyncio
import hashlib
import json import json
import re from typing import Any
import time
from typing import Any, Final
import httpx import httpx
@@ -24,22 +21,8 @@ from ..config.settings import (
from ..core.exceptions import ToolError from ..core.exceptions import ToolError
# Sensitive keys to redact from debug logs (frozenset — immutable, Final — no accidental reassignment) # Sensitive keys to redact from debug logs
_SENSITIVE_KEYS: Final[frozenset[str]] = frozenset( _SENSITIVE_KEYS = {"password", "key", "secret", "token", "apikey"}
{
"password",
"key",
"secret",
"token",
"apikey",
"authorization",
"cookie",
"session",
"credential",
"passphrase",
"jwt",
}
)
def _is_sensitive_key(key: str) -> bool: def _is_sensitive_key(key: str) -> bool:
@@ -48,14 +31,14 @@ def _is_sensitive_key(key: str) -> bool:
return any(s in key_lower for s in _SENSITIVE_KEYS) return any(s in key_lower for s in _SENSITIVE_KEYS)
def redact_sensitive(obj: Any) -> Any: def _redact_sensitive(obj: Any) -> Any:
"""Recursively redact sensitive values from nested dicts/lists.""" """Recursively redact sensitive values from nested dicts/lists."""
if isinstance(obj, dict): if isinstance(obj, dict):
return { return {
k: ("***" if _is_sensitive_key(k) else redact_sensitive(v)) for k, v in obj.items() k: ("***" if _is_sensitive_key(k) else _redact_sensitive(v)) for k, v in obj.items()
} }
if isinstance(obj, list): if isinstance(obj, list):
return [redact_sensitive(item) for item in obj] return [_redact_sensitive(item) for item in obj]
return obj return obj
@@ -83,116 +66,8 @@ def get_timeout_for_operation(profile: str) -> httpx.Timeout:
# Global connection pool (module-level singleton) # Global connection pool (module-level singleton)
# Python 3.12+ asyncio.Lock() is safe at module level — no running event loop required
_http_client: httpx.AsyncClient | None = None _http_client: httpx.AsyncClient | None = None
_client_lock: Final[asyncio.Lock] = asyncio.Lock() _client_lock = asyncio.Lock()
class _RateLimiter:
"""Token bucket rate limiter for Unraid API (100 req / 10s hard limit).
Uses 90 tokens with 9.0 tokens/sec refill for 10% safety headroom.
"""
def __init__(self, max_tokens: int = 90, refill_rate: float = 9.0) -> None:
self.max_tokens = max_tokens
self.tokens = float(max_tokens)
self.refill_rate = refill_rate # tokens per second
self.last_refill = time.monotonic()
# asyncio.Lock() is safe to create at __init__ time (Python 3.12+)
self._lock: Final[asyncio.Lock] = asyncio.Lock()
def _refill(self) -> None:
"""Refill tokens based on elapsed time."""
now = time.monotonic()
elapsed = now - self.last_refill
self.tokens = min(self.max_tokens, self.tokens + elapsed * self.refill_rate)
self.last_refill = now
async def acquire(self) -> None:
"""Consume one token, waiting if necessary for refill."""
while True:
async with self._lock:
self._refill()
if self.tokens >= 1:
self.tokens -= 1
return
wait_time = (1 - self.tokens) / self.refill_rate
# Sleep outside the lock so other coroutines aren't blocked
await asyncio.sleep(wait_time)
_rate_limiter = _RateLimiter()
# --- TTL Cache for stable read-only queries ---
# Queries whose results change infrequently and are safe to cache.
# Mutations and volatile queries (metrics, docker, array state) are excluded.
_CACHEABLE_QUERY_PREFIXES = frozenset(
{
"GetNetworkConfig",
"GetRegistrationInfo",
"GetOwner",
"GetFlash",
}
)
_CACHE_TTL_SECONDS = 60.0
_OPERATION_NAME_PATTERN = re.compile(r"^(?:query\s+)?([_A-Za-z][_0-9A-Za-z]*)\b")
class _QueryCache:
"""Simple TTL cache for GraphQL query responses.
Keyed by a hash of (query, variables). Entries expire after _CACHE_TTL_SECONDS.
Only caches responses for queries whose operation name is in _CACHEABLE_QUERY_PREFIXES.
Mutation requests always bypass the cache.
"""
def __init__(self) -> None:
self._store: dict[str, tuple[float, dict[str, Any]]] = {}
@staticmethod
def _cache_key(query: str, variables: dict[str, Any] | None) -> str:
raw = query + json.dumps(variables or {}, sort_keys=True)
return hashlib.sha256(raw.encode()).hexdigest()
@staticmethod
def is_cacheable(query: str) -> bool:
"""Check if a query is eligible for caching based on its operation name."""
normalized = query.lstrip()
if normalized.startswith("mutation"):
return False
match = _OPERATION_NAME_PATTERN.match(normalized)
if not match:
return False
return match.group(1) in _CACHEABLE_QUERY_PREFIXES
def get(self, query: str, variables: dict[str, Any] | None) -> dict[str, Any] | None:
"""Return cached result if present and not expired, else None."""
key = self._cache_key(query, variables)
entry = self._store.get(key)
if entry is None:
return None
expires_at, data = entry
if time.monotonic() > expires_at:
del self._store[key]
return None
return data
def put(self, query: str, variables: dict[str, Any] | None, data: dict[str, Any]) -> None:
"""Store a query result with TTL expiry."""
key = self._cache_key(query, variables)
self._store[key] = (time.monotonic() + _CACHE_TTL_SECONDS, data)
def invalidate_all(self) -> None:
"""Clear the entire cache (called after mutations)."""
self._store.clear()
_query_cache = _QueryCache()
def is_idempotent_error(error_message: str, operation: str) -> bool: def is_idempotent_error(error_message: str, operation: str) -> bool:
@@ -234,7 +109,7 @@ async def _create_http_client() -> httpx.AsyncClient:
return httpx.AsyncClient( return httpx.AsyncClient(
# Connection pool settings # Connection pool settings
limits=httpx.Limits( limits=httpx.Limits(
max_keepalive_connections=20, max_connections=20, keepalive_expiry=30.0 max_keepalive_connections=20, max_connections=100, keepalive_expiry=30.0
), ),
# Default timeout (can be overridden per-request) # Default timeout (can be overridden per-request)
timeout=DEFAULT_TIMEOUT, timeout=DEFAULT_TIMEOUT,
@@ -248,28 +123,33 @@ async def _create_http_client() -> httpx.AsyncClient:
async def get_http_client() -> httpx.AsyncClient: async def get_http_client() -> httpx.AsyncClient:
"""Get or create shared HTTP client with connection pooling. """Get or create shared HTTP client with connection pooling.
Uses double-checked locking: fast-path skips the lock when the client The client is protected by an asyncio lock to prevent concurrent creation.
is already initialized, only acquiring it for initial creation or If the existing client was closed (e.g., during shutdown), a new one is created.
recovery after close.
Returns: Returns:
Singleton AsyncClient instance with connection pooling enabled Singleton AsyncClient instance with connection pooling enabled
""" """
global _http_client global _http_client
# Fast-path: skip lock if client is already initialized and open
client = _http_client
if client is not None and not client.is_closed:
return client
# Slow-path: acquire lock for initialization
async with _client_lock: async with _client_lock:
if _http_client is None or _http_client.is_closed: if _http_client is None or _http_client.is_closed:
_http_client = await _create_http_client() _http_client = await _create_http_client()
logger.info( logger.info(
"Created shared HTTP client with connection pooling (20 keepalive, 20 max connections)" "Created shared HTTP client with connection pooling (20 keepalive, 100 max connections)"
) )
return _http_client
client = _http_client
# Verify client is still open after releasing the lock.
# In asyncio's cooperative model this is unlikely to fail, but guards
# against edge cases where close_http_client runs between yield points.
if client.is_closed:
async with _client_lock:
_http_client = await _create_http_client()
client = _http_client
logger.info("Re-created HTTP client after unexpected close")
return client
async def close_http_client() -> None: async def close_http_client() -> None:
@@ -310,14 +190,6 @@ async def make_graphql_request(
if not UNRAID_API_KEY: if not UNRAID_API_KEY:
raise ToolError("UNRAID_API_KEY not configured") raise ToolError("UNRAID_API_KEY not configured")
# Check TTL cache for stable read-only queries
is_mutation = query.lstrip().startswith("mutation")
if not is_mutation and _query_cache.is_cacheable(query):
cached = _query_cache.get(query, variables)
if cached is not None:
logger.debug("Returning cached response for query")
return cached
headers = { headers = {
"Content-Type": "application/json", "Content-Type": "application/json",
"X-API-Key": UNRAID_API_KEY, "X-API-Key": UNRAID_API_KEY,
@@ -330,41 +202,19 @@ async def make_graphql_request(
logger.debug(f"Making GraphQL request to {UNRAID_API_URL}:") logger.debug(f"Making GraphQL request to {UNRAID_API_URL}:")
logger.debug(f"Query: {query[:200]}{'...' if len(query) > 200 else ''}") # Log truncated query logger.debug(f"Query: {query[:200]}{'...' if len(query) > 200 else ''}") # Log truncated query
if variables: if variables:
logger.debug(f"Variables: {redact_sensitive(variables)}") logger.debug(f"Variables: {_redact_sensitive(variables)}")
try: try:
# Rate limit: consume a token before making the request
await _rate_limiter.acquire()
# Get the shared HTTP client with connection pooling # Get the shared HTTP client with connection pooling
client = await get_http_client() client = await get_http_client()
# Retry loop for 429 rate limit responses # Override timeout if custom timeout specified
post_kwargs: dict[str, Any] = {"json": payload, "headers": headers}
if custom_timeout is not None: if custom_timeout is not None:
post_kwargs["timeout"] = custom_timeout response = await client.post(
UNRAID_API_URL, json=payload, headers=headers, timeout=custom_timeout
response: httpx.Response | None = None
for attempt in range(3):
response = await client.post(UNRAID_API_URL, **post_kwargs)
if response.status_code == 429:
backoff = 2**attempt
logger.warning(
f"Rate limited (429) by Unraid API, retrying in {backoff}s (attempt {attempt + 1}/3)"
)
await asyncio.sleep(backoff)
continue
break
if response is None: # pragma: no cover — guaranteed by loop
raise ToolError("No response received after retry attempts")
# Provide a clear message when all retries are exhausted on 429
if response.status_code == 429:
logger.error("Rate limit (429) persisted after 3 retries — request aborted")
raise ToolError(
"Unraid API is rate limiting requests. Wait ~10 seconds before retrying."
) )
else:
response = await client.post(UNRAID_API_URL, json=payload, headers=headers)
response.raise_for_status() # Raise an exception for HTTP error codes 4xx/5xx response.raise_for_status() # Raise an exception for HTTP error codes 4xx/5xx
@@ -395,27 +245,14 @@ async def make_graphql_request(
logger.debug("GraphQL request successful.") logger.debug("GraphQL request successful.")
data = response_data.get("data", {}) data = response_data.get("data", {})
result = data if isinstance(data, dict) else {} # Ensure we return dict return data if isinstance(data, dict) else {} # Ensure we return dict
# Invalidate cache on mutations; cache eligible query results
if is_mutation:
_query_cache.invalidate_all()
elif _query_cache.is_cacheable(query):
_query_cache.put(query, variables, result)
return result
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
# Log full details internally; only expose status code to MCP client
logger.error(f"HTTP error occurred: {e.response.status_code} - {e.response.text}") logger.error(f"HTTP error occurred: {e.response.status_code} - {e.response.text}")
raise ToolError( raise ToolError(f"HTTP error {e.response.status_code}: {e.response.text}") from e
f"Unraid API returned HTTP {e.response.status_code}. Check server logs for details."
) from e
except httpx.RequestError as e: except httpx.RequestError as e:
# Log full error internally; give safe summary to MCP client
logger.error(f"Request error occurred: {e}") logger.error(f"Request error occurred: {e}")
raise ToolError(f"Network error connecting to Unraid API: {type(e).__name__}") from e raise ToolError(f"Network connection error: {e!s}") from e
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
# Log full decode error; give safe summary to MCP client
logger.error(f"Failed to decode JSON response: {e}") logger.error(f"Failed to decode JSON response: {e}")
raise ToolError("Unraid API returned an invalid response (not valid JSON)") from e raise ToolError(f"Invalid JSON response from Unraid API: {e!s}") from e

View File

@@ -4,10 +4,6 @@ This module defines custom exception classes for consistent error handling
throughout the application, with proper integration to FastMCP's error system. throughout the application, with proper integration to FastMCP's error system.
""" """
import contextlib
import logging
from collections.abc import Iterator
from fastmcp.exceptions import ToolError as FastMCPToolError from fastmcp.exceptions import ToolError as FastMCPToolError
@@ -23,34 +19,36 @@ class ToolError(FastMCPToolError):
pass pass
@contextlib.contextmanager class ConfigurationError(ToolError):
def tool_error_handler( """Raised when there are configuration-related errors."""
tool_name: str,
action: str,
logger: logging.Logger,
) -> Iterator[None]:
"""Context manager that standardizes tool error handling.
Re-raises ToolError as-is. Gives TimeoutError a descriptive message. pass
Catches all other exceptions, logs them with full traceback, and wraps them
in ToolError with a descriptive message.
Args:
tool_name: The tool name for error messages (e.g., "docker", "vm"). class UnraidAPIError(ToolError):
action: The current action being executed. """Raised when the Unraid API returns an error or is unreachable."""
logger: The logger instance to use for error logging.
pass
class SubscriptionError(ToolError):
"""Raised when there are WebSocket subscription-related errors."""
pass
class ValidationError(ToolError):
"""Raised when input validation fails."""
pass
class IdempotentOperationError(ToolError):
"""Raised when an operation is idempotent (already in desired state).
This is used internally to signal that an operation was already complete,
which should typically be converted to a success response rather than
propagated as an error to the user.
""" """
try:
yield pass
except ToolError:
raise
except TimeoutError as e:
logger.exception(f"Timeout in unraid_{tool_name} action={action}: request exceeded time limit")
raise ToolError(
f"Request timed out executing {tool_name}/{action}. The Unraid API did not respond in time."
) from e
except Exception as e:
logger.exception(f"Error in unraid_{tool_name} action={action}")
raise ToolError(
f"Failed to execute {tool_name}/{action}. Check server logs for details."
) from e

View File

@@ -9,21 +9,38 @@ from datetime import datetime
from typing import Any from typing import Any
@dataclass(slots=True) @dataclass
class SubscriptionData: class SubscriptionData:
"""Container for subscription data with metadata. """Container for subscription data with metadata."""
Note: last_updated must be timezone-aware (use datetime.now(UTC)).
"""
data: dict[str, Any] data: dict[str, Any]
last_updated: datetime # Must be timezone-aware (UTC) last_updated: datetime
subscription_type: str subscription_type: str
def __post_init__(self) -> None:
if self.last_updated.tzinfo is None: @dataclass
raise ValueError( class SystemHealth:
"last_updated must be timezone-aware; use datetime.now(UTC)" """Container for system health status information."""
)
if not self.subscription_type.strip(): is_healthy: bool
raise ValueError("subscription_type must be a non-empty string") issues: list[str]
warnings: list[str]
last_checked: datetime
component_status: dict[str, str]
@dataclass
class APIResponse:
"""Container for standardized API response data."""
success: bool
data: dict[str, Any] | None = None
error: str | None = None
metadata: dict[str, Any] | None = None
# Type aliases for common data structures
ConfigValue = str | int | bool | float | None
ConfigDict = dict[str, ConfigValue]
GraphQLVariables = dict[str, Any]
HealthStatus = dict[str, str | bool | int | list[Any]]

View File

@@ -1,89 +0,0 @@
"""Shared utility functions for Unraid MCP tools."""
from typing import Any
from urllib.parse import urlparse
def safe_get(data: dict[str, Any], *keys: str, default: Any = None) -> Any:
"""Safely traverse nested dict keys, handling None intermediates.
Args:
data: The root dictionary to traverse.
*keys: Sequence of keys to follow.
default: Value to return if any key is missing or None.
Returns:
The value at the end of the key chain, or default if unreachable.
Explicit ``None`` values at the final key also return ``default``.
"""
current = data
for key in keys:
if not isinstance(current, dict):
return default
current = current.get(key)
return current if current is not None else default
def format_bytes(bytes_value: int | None) -> str:
"""Format byte values into human-readable sizes.
Args:
bytes_value: Number of bytes, or None.
Returns:
Human-readable string like "1.00 GB" or "N/A" if input is None/invalid.
"""
if bytes_value is None:
return "N/A"
try:
value = float(int(bytes_value))
except (ValueError, TypeError):
return "N/A"
for unit in ["B", "KB", "MB", "GB", "TB", "PB"]:
if value < 1024.0:
return f"{value:.2f} {unit}"
value /= 1024.0
return f"{value:.2f} EB"
def safe_display_url(url: str | None) -> str | None:
"""Return a redacted URL showing only scheme + host + port.
Strips path, query parameters, credentials, and fragments to avoid
leaking internal network topology or embedded secrets (CWE-200).
"""
if not url:
return None
try:
parsed = urlparse(url)
host = parsed.hostname or "unknown"
if parsed.port:
return f"{parsed.scheme}://{host}:{parsed.port}"
return f"{parsed.scheme}://{host}"
except ValueError:
# urlparse raises ValueError for invalid URLs (e.g. contains control chars)
return "<unparseable>"
def format_kb(k: Any) -> str:
"""Format kilobyte values into human-readable sizes.
Args:
k: Number of kilobytes, or None.
Returns:
Human-readable string like "1.00 GB" or "N/A" if input is None/invalid.
"""
if k is None:
return "N/A"
try:
k = int(k)
except (ValueError, TypeError):
return "N/A"
if k >= 1024 * 1024 * 1024:
return f"{k / (1024 * 1024 * 1024):.2f} TB"
if k >= 1024 * 1024:
return f"{k / (1024 * 1024):.2f} GB"
if k >= 1024:
return f"{k / 1024:.2f} MB"
return f"{k:.2f} KB"

View File

@@ -15,11 +15,8 @@ from .config.settings import (
UNRAID_MCP_HOST, UNRAID_MCP_HOST,
UNRAID_MCP_PORT, UNRAID_MCP_PORT,
UNRAID_MCP_TRANSPORT, UNRAID_MCP_TRANSPORT,
UNRAID_VERIFY_SSL,
VERSION, VERSION,
validate_required_config,
) )
from .subscriptions.diagnostics import register_diagnostic_tools
from .subscriptions.resources import register_subscription_resources from .subscriptions.resources import register_subscription_resources
from .tools.array import register_array_tool from .tools.array import register_array_tool
from .tools.docker import register_docker_tool from .tools.docker import register_docker_tool
@@ -47,10 +44,9 @@ mcp = FastMCP(
def register_all_modules() -> None: def register_all_modules() -> None:
"""Register all tools and resources with the MCP instance.""" """Register all tools and resources with the MCP instance."""
try: try:
# Register subscription resources and diagnostic tools # Register subscription resources first
register_subscription_resources(mcp) register_subscription_resources(mcp)
register_diagnostic_tools(mcp) logger.info("Subscription resources registered")
logger.info("Subscription resources and diagnostic tools registered")
# Register all consolidated tools # Register all consolidated tools
registrars = [ registrars = [
@@ -77,15 +73,6 @@ def register_all_modules() -> None:
def run_server() -> None: def run_server() -> None:
"""Run the MCP server with the configured transport.""" """Run the MCP server with the configured transport."""
# Validate required configuration before anything else
is_valid, missing = validate_required_config()
if not is_valid:
logger.critical(
f"Missing required configuration: {', '.join(missing)}. "
"Set these environment variables or add them to your .env file."
)
sys.exit(1)
# Log configuration # Log configuration
if UNRAID_API_URL: if UNRAID_API_URL:
logger.info(f"UNRAID_API_URL loaded: {UNRAID_API_URL[:20]}...") logger.info(f"UNRAID_API_URL loaded: {UNRAID_API_URL[:20]}...")
@@ -101,13 +88,6 @@ def run_server() -> None:
logger.info(f"UNRAID_MCP_HOST set to: {UNRAID_MCP_HOST}") logger.info(f"UNRAID_MCP_HOST set to: {UNRAID_MCP_HOST}")
logger.info(f"UNRAID_MCP_TRANSPORT set to: {UNRAID_MCP_TRANSPORT}") logger.info(f"UNRAID_MCP_TRANSPORT set to: {UNRAID_MCP_TRANSPORT}")
if UNRAID_VERIFY_SSL is False:
logger.warning(
"SSL VERIFICATION DISABLED (UNRAID_VERIFY_SSL=false). "
"Connections to Unraid API are vulnerable to man-in-the-middle attacks. "
"Only use this in trusted networks or for development."
)
# Register all modules # Register all modules
register_all_modules() register_all_modules()

View File

@@ -6,10 +6,8 @@ development and debugging purposes.
""" """
import asyncio import asyncio
import contextlib
import json import json
import re from datetime import datetime
from datetime import UTC, datetime
from typing import Any from typing import Any
import websockets import websockets
@@ -19,63 +17,9 @@ from websockets.typing import Subprotocol
from ..config.logging import logger from ..config.logging import logger
from ..config.settings import UNRAID_API_KEY, UNRAID_API_URL from ..config.settings import UNRAID_API_KEY, UNRAID_API_URL
from ..core.exceptions import ToolError from ..core.exceptions import ToolError
from ..core.utils import safe_display_url
from .manager import subscription_manager from .manager import subscription_manager
from .resources import ensure_subscriptions_started from .resources import ensure_subscriptions_started
from .utils import build_ws_ssl_context, build_ws_url from .utils import build_ws_ssl_context
_ALLOWED_SUBSCRIPTION_NAMES = frozenset(
{
"logFileSubscription",
"containerStatsSubscription",
"cpuSubscription",
"memorySubscription",
"arraySubscription",
"networkSubscription",
"dockerSubscription",
"vmSubscription",
}
)
# Pattern: must start with "subscription" and contain only a known subscription name.
# _FORBIDDEN_KEYWORDS rejects any query that contains standalone "mutation" or "query"
# as distinct words. Word boundaries (\b) ensure "mutationField"-style identifiers are
# not rejected — only bare "mutation" or "query" operation keywords are blocked.
_SUBSCRIPTION_NAME_PATTERN = re.compile(r"^\s*subscription\b[^{]*\{\s*(\w+)", re.IGNORECASE)
_FORBIDDEN_KEYWORDS = re.compile(r"\b(mutation|query)\b", re.IGNORECASE)
def _validate_subscription_query(query: str) -> str:
"""Validate that a subscription query is safe to execute.
Only allows subscription operations targeting whitelisted subscription names.
Rejects any query containing mutation/query keywords.
Returns:
The extracted subscription name.
Raises:
ToolError: If the query fails validation.
"""
if _FORBIDDEN_KEYWORDS.search(query):
raise ToolError("Query rejected: must be a subscription, not a mutation or query.")
match = _SUBSCRIPTION_NAME_PATTERN.match(query)
if not match:
raise ToolError(
"Query rejected: must start with 'subscription' and contain a valid "
"subscription operation. Example: subscription { logFileSubscription { ... } }"
)
sub_name = match.group(1)
if sub_name not in _ALLOWED_SUBSCRIPTION_NAMES:
raise ToolError(
f"Subscription '{sub_name}' is not allowed. "
f"Allowed subscriptions: {sorted(_ALLOWED_SUBSCRIPTION_NAMES)}"
)
return sub_name
def register_diagnostic_tools(mcp: FastMCP) -> None: def register_diagnostic_tools(mcp: FastMCP) -> None:
@@ -90,10 +34,6 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
"""Test a GraphQL subscription query directly to debug schema issues. """Test a GraphQL subscription query directly to debug schema issues.
Use this to find working subscription field names and structure. Use this to find working subscription field names and structure.
Only whitelisted subscriptions are allowed (logFileSubscription,
containerStatsSubscription, cpuSubscription, memorySubscription,
arraySubscription, networkSubscription, dockerSubscription,
vmSubscription).
Args: Args:
subscription_query: The GraphQL subscription query to test subscription_query: The GraphQL subscription query to test
@@ -101,16 +41,16 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
Returns: Returns:
Dict containing test results and response data Dict containing test results and response data
""" """
# Validate before any network I/O
sub_name = _validate_subscription_query(subscription_query)
try: try:
logger.info(f"[TEST_SUBSCRIPTION] Testing validated subscription '{sub_name}'") logger.info(f"[TEST_SUBSCRIPTION] Testing query: {subscription_query}")
try: # Build WebSocket URL
ws_url = build_ws_url() if not UNRAID_API_URL:
except ValueError as e: raise ToolError("UNRAID_API_URL is not configured")
raise ToolError(str(e)) from e ws_url = (
UNRAID_API_URL.replace("https://", "wss://").replace("http://", "ws://")
+ "/graphql"
)
ssl_context = build_ws_ssl_context(ws_url) ssl_context = build_ws_ssl_context(ws_url)
@@ -119,7 +59,6 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
ws_url, ws_url,
subprotocols=[Subprotocol("graphql-transport-ws"), Subprotocol("graphql-ws")], subprotocols=[Subprotocol("graphql-transport-ws"), Subprotocol("graphql-ws")],
ssl=ssl_context, ssl=ssl_context,
open_timeout=10,
ping_interval=30, ping_interval=30,
ping_timeout=10, ping_timeout=10,
) as websocket: ) as websocket:
@@ -163,8 +102,6 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
"note": "Connection successful, subscription may be waiting for events", "note": "Connection successful, subscription may be waiting for events",
} }
except ToolError:
raise
except Exception as e: except Exception as e:
logger.error(f"[TEST_SUBSCRIPTION] Error: {e}", exc_info=True) logger.error(f"[TEST_SUBSCRIPTION] Error: {e}", exc_info=True)
return {"error": str(e), "query_tested": subscription_query} return {"error": str(e), "query_tested": subscription_query}
@@ -185,18 +122,18 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
logger.info("[DIAGNOSTIC] Running subscription diagnostics...") logger.info("[DIAGNOSTIC] Running subscription diagnostics...")
# Get comprehensive status # Get comprehensive status
status = await subscription_manager.get_subscription_status() status = subscription_manager.get_subscription_status()
# Initialize connection issues list with proper type # Initialize connection issues list with proper type
connection_issues: list[dict[str, Any]] = [] connection_issues: list[dict[str, Any]] = []
# Add environment info with explicit typing # Add environment info with explicit typing
diagnostic_info: dict[str, Any] = { diagnostic_info: dict[str, Any] = {
"timestamp": datetime.now(UTC).isoformat(), "timestamp": datetime.now().isoformat(),
"environment": { "environment": {
"auto_start_enabled": subscription_manager.auto_start_enabled, "auto_start_enabled": subscription_manager.auto_start_enabled,
"max_reconnect_attempts": subscription_manager.max_reconnect_attempts, "max_reconnect_attempts": subscription_manager.max_reconnect_attempts,
"unraid_api_url": safe_display_url(UNRAID_API_URL), "unraid_api_url": UNRAID_API_URL[:50] + "..." if UNRAID_API_URL else None,
"api_key_configured": bool(UNRAID_API_KEY), "api_key_configured": bool(UNRAID_API_KEY),
"websocket_url": None, "websocket_url": None,
}, },
@@ -215,9 +152,17 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
}, },
} }
# Calculate WebSocket URL (stays None if UNRAID_API_URL not configured) # Calculate WebSocket URL
with contextlib.suppress(ValueError): if UNRAID_API_URL:
diagnostic_info["environment"]["websocket_url"] = build_ws_url() if UNRAID_API_URL.startswith("https://"):
ws_url = "wss://" + UNRAID_API_URL[len("https://") :]
elif UNRAID_API_URL.startswith("http://"):
ws_url = "ws://" + UNRAID_API_URL[len("http://") :]
else:
ws_url = UNRAID_API_URL
if not ws_url.endswith("/graphql"):
ws_url = ws_url.rstrip("/") + "/graphql"
diagnostic_info["environment"]["websocket_url"] = ws_url
# Analyze issues # Analyze issues
for sub_name, sub_status in status.items(): for sub_name, sub_status in status.items():

View File

@@ -8,74 +8,16 @@ error handling, reconnection logic, and authentication.
import asyncio import asyncio
import json import json
import os import os
import time from datetime import datetime
from datetime import UTC, datetime
from typing import Any from typing import Any
import websockets import websockets
from websockets.typing import Subprotocol from websockets.typing import Subprotocol
from ..config.logging import logger from ..config.logging import logger
from ..config.settings import UNRAID_API_KEY from ..config.settings import UNRAID_API_KEY, UNRAID_API_URL
from ..core.client import redact_sensitive
from ..core.types import SubscriptionData from ..core.types import SubscriptionData
from .utils import build_ws_ssl_context, build_ws_url from .utils import build_ws_ssl_context
# Resource data size limits to prevent unbounded memory growth
_MAX_RESOURCE_DATA_BYTES = 1_048_576 # 1MB
_MAX_RESOURCE_DATA_LINES = 5_000
# Minimum stable connection duration (seconds) before resetting reconnect counter
_STABLE_CONNECTION_SECONDS = 30
def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]:
"""Cap log content in subscription data to prevent unbounded memory growth.
Returns a new dict — does NOT mutate the input. If any nested 'content'
field (from log subscriptions) exceeds the byte limit, truncate it to the
most recent _MAX_RESOURCE_DATA_LINES lines.
The final content is guaranteed to be <= _MAX_RESOURCE_DATA_BYTES.
"""
result: dict[str, Any] = {}
for key, value in data.items():
if isinstance(value, dict):
result[key] = _cap_log_content(value)
elif (
key == "content"
and isinstance(value, str)
and len(value.encode("utf-8", errors="replace")) > _MAX_RESOURCE_DATA_BYTES
):
lines = value.splitlines()
original_line_count = len(lines)
# Keep most recent lines first.
if len(lines) > _MAX_RESOURCE_DATA_LINES:
lines = lines[-_MAX_RESOURCE_DATA_LINES:]
# Enforce byte cap while preserving whole-line boundaries where possible.
truncated = "\n".join(lines)
truncated_bytes = truncated.encode("utf-8", errors="replace")
while len(lines) > 1 and len(truncated_bytes) > _MAX_RESOURCE_DATA_BYTES:
lines = lines[1:]
truncated = "\n".join(lines)
truncated_bytes = truncated.encode("utf-8", errors="replace")
# Last resort: if a single line still exceeds cap, hard-cap bytes.
if len(truncated_bytes) > _MAX_RESOURCE_DATA_BYTES:
truncated = truncated_bytes[-_MAX_RESOURCE_DATA_BYTES :].decode(
"utf-8", errors="ignore"
)
logger.warning(
f"[RESOURCE] Capped log content from {original_line_count} to "
f"{len(lines)} lines ({len(value)} -> {len(truncated)} chars)"
)
result[key] = truncated
else:
result[key] = value
return result
class SubscriptionManager: class SubscriptionManager:
@@ -84,6 +26,7 @@ class SubscriptionManager:
def __init__(self) -> None: def __init__(self) -> None:
self.active_subscriptions: dict[str, asyncio.Task[None]] = {} self.active_subscriptions: dict[str, asyncio.Task[None]] = {}
self.resource_data: dict[str, SubscriptionData] = {} self.resource_data: dict[str, SubscriptionData] = {}
self.websocket: websockets.WebSocketServerProtocol | None = None
self.subscription_lock = asyncio.Lock() self.subscription_lock = asyncio.Lock()
# Configuration # Configuration
@@ -94,7 +37,6 @@ class SubscriptionManager:
self.max_reconnect_attempts = int(os.getenv("UNRAID_MAX_RECONNECT_ATTEMPTS", "10")) self.max_reconnect_attempts = int(os.getenv("UNRAID_MAX_RECONNECT_ATTEMPTS", "10"))
self.connection_states: dict[str, str] = {} # Track connection state per subscription self.connection_states: dict[str, str] = {} # Track connection state per subscription
self.last_error: dict[str, str] = {} # Track last error per subscription self.last_error: dict[str, str] = {} # Track last error per subscription
self._connection_start_times: dict[str, float] = {} # Track when connections started
# Define subscription configurations # Define subscription configurations
self.subscription_configs = { self.subscription_configs = {
@@ -163,7 +105,6 @@ class SubscriptionManager:
# Reset connection tracking # Reset connection tracking
self.reconnect_attempts[subscription_name] = 0 self.reconnect_attempts[subscription_name] = 0
self.connection_states[subscription_name] = "starting" self.connection_states[subscription_name] = "starting"
self._connection_start_times.pop(subscription_name, None)
async with self.subscription_lock: async with self.subscription_lock:
try: try:
@@ -197,7 +138,6 @@ class SubscriptionManager:
logger.debug(f"[SUBSCRIPTION:{subscription_name}] Task cancelled successfully") logger.debug(f"[SUBSCRIPTION:{subscription_name}] Task cancelled successfully")
del self.active_subscriptions[subscription_name] del self.active_subscriptions[subscription_name]
self.connection_states[subscription_name] = "stopped" self.connection_states[subscription_name] = "stopped"
self._connection_start_times.pop(subscription_name, None)
logger.info(f"[SUBSCRIPTION:{subscription_name}] Subscription stopped") logger.info(f"[SUBSCRIPTION:{subscription_name}] Subscription stopped")
else: else:
logger.warning(f"[SUBSCRIPTION:{subscription_name}] No active subscription to stop") logger.warning(f"[SUBSCRIPTION:{subscription_name}] No active subscription to stop")
@@ -225,7 +165,20 @@ class SubscriptionManager:
break break
try: try:
ws_url = build_ws_url() # Build WebSocket URL with detailed logging
if not UNRAID_API_URL:
raise ValueError("UNRAID_API_URL is not configured")
if UNRAID_API_URL.startswith("https://"):
ws_url = "wss://" + UNRAID_API_URL[len("https://") :]
elif UNRAID_API_URL.startswith("http://"):
ws_url = "ws://" + UNRAID_API_URL[len("http://") :]
else:
ws_url = UNRAID_API_URL
if not ws_url.endswith("/graphql"):
ws_url = ws_url.rstrip("/") + "/graphql"
logger.debug(f"[WEBSOCKET:{subscription_name}] Connecting to: {ws_url}") logger.debug(f"[WEBSOCKET:{subscription_name}] Connecting to: {ws_url}")
logger.debug( logger.debug(
f"[WEBSOCKET:{subscription_name}] API Key present: {'Yes' if UNRAID_API_KEY else 'No'}" f"[WEBSOCKET:{subscription_name}] API Key present: {'Yes' if UNRAID_API_KEY else 'No'}"
@@ -242,7 +195,6 @@ class SubscriptionManager:
async with websockets.connect( async with websockets.connect(
ws_url, ws_url,
subprotocols=[Subprotocol("graphql-transport-ws"), Subprotocol("graphql-ws")], subprotocols=[Subprotocol("graphql-transport-ws"), Subprotocol("graphql-ws")],
open_timeout=connect_timeout,
ping_interval=20, ping_interval=20,
ping_timeout=10, ping_timeout=10,
close_timeout=10, close_timeout=10,
@@ -254,9 +206,9 @@ class SubscriptionManager:
) )
self.connection_states[subscription_name] = "connected" self.connection_states[subscription_name] = "connected"
# Track connection start time — only reset retry counter # Reset retry count on successful connection
# after the connection proves stable (>30s connected) self.reconnect_attempts[subscription_name] = 0
self._connection_start_times[subscription_name] = time.monotonic() retry_delay = 5 # Reset delay
# Initialize GraphQL-WS protocol # Initialize GraphQL-WS protocol
logger.debug( logger.debug(
@@ -338,9 +290,7 @@ class SubscriptionManager:
f"[SUBSCRIPTION:{subscription_name}] Subscription message type: {start_type}" f"[SUBSCRIPTION:{subscription_name}] Subscription message type: {start_type}"
) )
logger.debug(f"[SUBSCRIPTION:{subscription_name}] Query: {query[:100]}...") logger.debug(f"[SUBSCRIPTION:{subscription_name}] Query: {query[:100]}...")
logger.debug( logger.debug(f"[SUBSCRIPTION:{subscription_name}] Variables: {variables}")
f"[SUBSCRIPTION:{subscription_name}] Variables: {redact_sensitive(variables)}"
)
await websocket.send(json.dumps(subscription_message)) await websocket.send(json.dumps(subscription_message))
logger.info( logger.info(
@@ -376,18 +326,11 @@ class SubscriptionManager:
logger.info( logger.info(
f"[DATA:{subscription_name}] Received subscription data update" f"[DATA:{subscription_name}] Received subscription data update"
) )
capped_data = ( self.resource_data[subscription_name] = SubscriptionData(
_cap_log_content(payload["data"]) data=payload["data"],
if isinstance(payload["data"], dict) last_updated=datetime.now(),
else payload["data"]
)
new_entry = SubscriptionData(
data=capped_data,
last_updated=datetime.now(UTC),
subscription_type=subscription_name, subscription_type=subscription_name,
) )
async with self.subscription_lock:
self.resource_data[subscription_name] = new_entry
logger.debug( logger.debug(
f"[RESOURCE:{subscription_name}] Resource data updated successfully" f"[RESOURCE:{subscription_name}] Resource data updated successfully"
) )
@@ -448,8 +391,7 @@ class SubscriptionManager:
logger.error(f"[PROTOCOL:{subscription_name}] JSON decode error: {e}") logger.error(f"[PROTOCOL:{subscription_name}] JSON decode error: {e}")
except Exception as e: except Exception as e:
logger.error( logger.error(
f"[DATA:{subscription_name}] Error processing message: {e}", f"[DATA:{subscription_name}] Error processing message: {e}"
exc_info=True,
) )
msg_preview = ( msg_preview = (
message[:200] message[:200]
@@ -479,39 +421,11 @@ class SubscriptionManager:
self.connection_states[subscription_name] = "invalid_uri" self.connection_states[subscription_name] = "invalid_uri"
break # Don't retry on invalid URI break # Don't retry on invalid URI
except ValueError as e: except Exception as e:
# Non-retryable configuration error (e.g. UNRAID_API_URL not set) error_msg = f"Unexpected error: {e}"
error_msg = f"Configuration error: {e}"
logger.error(f"[WEBSOCKET:{subscription_name}] {error_msg}") logger.error(f"[WEBSOCKET:{subscription_name}] {error_msg}")
self.last_error[subscription_name] = error_msg self.last_error[subscription_name] = error_msg
self.connection_states[subscription_name] = "error" self.connection_states[subscription_name] = "error"
break # Don't retry on configuration errors
except Exception as e:
error_msg = f"Unexpected error: {e}"
logger.error(f"[WEBSOCKET:{subscription_name}] {error_msg}", exc_info=True)
self.last_error[subscription_name] = error_msg
self.connection_states[subscription_name] = "error"
# Check if connection was stable before deciding on retry behavior
start_time = self._connection_start_times.pop(subscription_name, None)
if start_time is not None:
connected_duration = time.monotonic() - start_time
if connected_duration >= _STABLE_CONNECTION_SECONDS:
# Connection was stable — reset retry counter and backoff
logger.info(
f"[WEBSOCKET:{subscription_name}] Connection was stable "
f"({connected_duration:.0f}s >= {_STABLE_CONNECTION_SECONDS}s), "
f"resetting retry counter"
)
self.reconnect_attempts[subscription_name] = 0
retry_delay = 5
else:
logger.warning(
f"[WEBSOCKET:{subscription_name}] Connection was unstable "
f"({connected_duration:.0f}s < {_STABLE_CONNECTION_SECONDS}s), "
f"keeping retry counter at {self.reconnect_attempts.get(subscription_name, 0)}"
)
# Calculate backoff delay # Calculate backoff delay
retry_delay = min(retry_delay * 1.5, max_retry_delay) retry_delay = min(retry_delay * 1.5, max_retry_delay)
@@ -521,24 +435,13 @@ class SubscriptionManager:
self.connection_states[subscription_name] = "reconnecting" self.connection_states[subscription_name] = "reconnecting"
await asyncio.sleep(retry_delay) await asyncio.sleep(retry_delay)
# The while loop exited (via break or max_retries exceeded). def get_resource_data(self, resource_name: str) -> dict[str, Any] | None:
# Remove from active_subscriptions so start_subscription() can restart it.
async with self.subscription_lock:
self.active_subscriptions.pop(subscription_name, None)
logger.info(
f"[SUBSCRIPTION:{subscription_name}] Subscription loop ended — "
f"removed from active_subscriptions. Final state: "
f"{self.connection_states.get(subscription_name, 'unknown')}"
)
async def get_resource_data(self, resource_name: str) -> dict[str, Any] | None:
"""Get current resource data with enhanced logging.""" """Get current resource data with enhanced logging."""
logger.debug(f"[RESOURCE:{resource_name}] Resource data requested") logger.debug(f"[RESOURCE:{resource_name}] Resource data requested")
async with self.subscription_lock:
if resource_name in self.resource_data: if resource_name in self.resource_data:
data = self.resource_data[resource_name] data = self.resource_data[resource_name]
age_seconds = (datetime.now(UTC) - data.last_updated).total_seconds() age_seconds = (datetime.now() - data.last_updated).total_seconds()
logger.debug(f"[RESOURCE:{resource_name}] Data found, age: {age_seconds:.1f}s") logger.debug(f"[RESOURCE:{resource_name}] Data found, age: {age_seconds:.1f}s")
return data.data return data.data
logger.debug(f"[RESOURCE:{resource_name}] No data available") logger.debug(f"[RESOURCE:{resource_name}] No data available")
@@ -550,11 +453,10 @@ class SubscriptionManager:
logger.debug(f"[SUBSCRIPTION_MANAGER] Active subscriptions: {active}") logger.debug(f"[SUBSCRIPTION_MANAGER] Active subscriptions: {active}")
return active return active
async def get_subscription_status(self) -> dict[str, dict[str, Any]]: def get_subscription_status(self) -> dict[str, dict[str, Any]]:
"""Get detailed status of all subscriptions for diagnostics.""" """Get detailed status of all subscriptions for diagnostics."""
status = {} status = {}
async with self.subscription_lock:
for sub_name, config in self.subscription_configs.items(): for sub_name, config in self.subscription_configs.items():
sub_status = { sub_status = {
"config": { "config": {
@@ -573,7 +475,7 @@ class SubscriptionManager:
# Add data info if available # Add data info if available
if sub_name in self.resource_data: if sub_name in self.resource_data:
data_info = self.resource_data[sub_name] data_info = self.resource_data[sub_name]
age_seconds = (datetime.now(UTC) - data_info.last_updated).total_seconds() age_seconds = (datetime.now() - data_info.last_updated).total_seconds()
sub_status["data"] = { sub_status["data"] = {
"available": True, "available": True,
"last_updated": data_info.last_updated.isoformat(), "last_updated": data_info.last_updated.isoformat(),

View File

@@ -44,7 +44,6 @@ async def autostart_subscriptions() -> None:
logger.info("[AUTOSTART] Auto-start process completed successfully") logger.info("[AUTOSTART] Auto-start process completed successfully")
except Exception as e: except Exception as e:
logger.error(f"[AUTOSTART] Failed during auto-start process: {e}", exc_info=True) logger.error(f"[AUTOSTART] Failed during auto-start process: {e}", exc_info=True)
raise # Propagate so ensure_subscriptions_started doesn't mark as started
# Optional log file subscription # Optional log file subscription
log_path = os.getenv("UNRAID_AUTOSTART_LOG_PATH") log_path = os.getenv("UNRAID_AUTOSTART_LOG_PATH")
@@ -83,7 +82,7 @@ def register_subscription_resources(mcp: FastMCP) -> None:
async def logs_stream_resource() -> str: async def logs_stream_resource() -> str:
"""Real-time log stream data from subscription.""" """Real-time log stream data from subscription."""
await ensure_subscriptions_started() await ensure_subscriptions_started()
data = await subscription_manager.get_resource_data("logFileSubscription") data = subscription_manager.get_resource_data("logFileSubscription")
if data: if data:
return json.dumps(data, indent=2) return json.dumps(data, indent=2)
return json.dumps( return json.dumps(

View File

@@ -2,34 +2,7 @@
import ssl as _ssl import ssl as _ssl
from ..config.settings import UNRAID_API_URL, UNRAID_VERIFY_SSL from ..config.settings import UNRAID_VERIFY_SSL
def build_ws_url() -> str:
"""Build a WebSocket URL from the configured UNRAID_API_URL.
Converts http(s) scheme to ws(s) and ensures /graphql path suffix.
Returns:
The WebSocket URL string (e.g. "wss://10.1.0.2:31337/graphql").
Raises:
ValueError: If UNRAID_API_URL is not configured.
"""
if not UNRAID_API_URL:
raise ValueError("UNRAID_API_URL is not configured")
if UNRAID_API_URL.startswith("https://"):
ws_url = "wss://" + UNRAID_API_URL[len("https://") :]
elif UNRAID_API_URL.startswith("http://"):
ws_url = "ws://" + UNRAID_API_URL[len("http://") :]
else:
ws_url = UNRAID_API_URL
if not ws_url.endswith("/graphql"):
ws_url = ws_url.rstrip("/") + "/graphql"
return ws_url
def build_ws_ssl_context(ws_url: str) -> _ssl.SSLContext | None: def build_ws_ssl_context(ws_url: str) -> _ssl.SSLContext | None:

View File

@@ -3,13 +3,13 @@
Provides the `unraid_array` tool with 5 actions for parity check management. Provides the `unraid_array` tool with 5 actions for parity check management.
""" """
from typing import Any, Literal, get_args from typing import Any, Literal
from fastmcp import FastMCP from fastmcp import FastMCP
from ..config.logging import logger from ..config.logging import logger
from ..core.client import make_graphql_request from ..core.client import make_graphql_request
from ..core.exceptions import ToolError, tool_error_handler from ..core.exceptions import ToolError
QUERIES: dict[str, str] = { QUERIES: dict[str, str] = {
@@ -53,14 +53,6 @@ ARRAY_ACTIONS = Literal[
"parity_status", "parity_status",
] ]
if set(get_args(ARRAY_ACTIONS)) != ALL_ACTIONS:
_missing = ALL_ACTIONS - set(get_args(ARRAY_ACTIONS))
_extra = set(get_args(ARRAY_ACTIONS)) - ALL_ACTIONS
raise RuntimeError(
f"ARRAY_ACTIONS and ALL_ACTIONS are out of sync. "
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
)
def register_array_tool(mcp: FastMCP) -> None: def register_array_tool(mcp: FastMCP) -> None:
"""Register the unraid_array tool with the FastMCP instance.""" """Register the unraid_array tool with the FastMCP instance."""
@@ -82,7 +74,7 @@ def register_array_tool(mcp: FastMCP) -> None:
if action not in ALL_ACTIONS: if action not in ALL_ACTIONS:
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(ALL_ACTIONS)}") raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(ALL_ACTIONS)}")
with tool_error_handler("array", action, logger): try:
logger.info(f"Executing unraid_array action={action}") logger.info(f"Executing unraid_array action={action}")
if action in QUERIES: if action in QUERIES:
@@ -103,4 +95,10 @@ def register_array_tool(mcp: FastMCP) -> None:
"data": data, "data": data,
} }
except ToolError:
raise
except Exception as e:
logger.error(f"Error in unraid_array action={action}: {e}", exc_info=True)
raise ToolError(f"Failed to execute array/{action}: {e!s}") from e
logger.info("Array tool registered successfully") logger.info("Array tool registered successfully")

View File

@@ -5,14 +5,13 @@ logs, networks, and update management.
""" """
import re import re
from typing import Any, Literal, get_args from typing import Any, Literal
from fastmcp import FastMCP from fastmcp import FastMCP
from ..config.logging import logger from ..config.logging import logger
from ..core.client import make_graphql_request from ..core.client import make_graphql_request
from ..core.exceptions import ToolError, tool_error_handler from ..core.exceptions import ToolError
from ..core.utils import safe_get
QUERIES: dict[str, str] = { QUERIES: dict[str, str] = {
@@ -99,10 +98,7 @@ MUTATIONS: dict[str, str] = {
""", """,
} }
DESTRUCTIVE_ACTIONS = {"remove", "update_all"} DESTRUCTIVE_ACTIONS = {"remove"}
# NOTE (Code-M-07): "details" and "logs" are listed here because they require a
# container_id parameter, but unlike mutations they use fuzzy name matching (not
# strict). This is intentional: read-only queries are safe with fuzzy matching.
_ACTIONS_REQUIRING_CONTAINER_ID = { _ACTIONS_REQUIRING_CONTAINER_ID = {
"start", "start",
"stop", "stop",
@@ -115,7 +111,6 @@ _ACTIONS_REQUIRING_CONTAINER_ID = {
"logs", "logs",
} }
ALL_ACTIONS = set(QUERIES) | set(MUTATIONS) | {"restart"} ALL_ACTIONS = set(QUERIES) | set(MUTATIONS) | {"restart"}
_MAX_TAIL_LINES = 10_000
DOCKER_ACTIONS = Literal[ DOCKER_ACTIONS = Literal[
"list", "list",
@@ -135,36 +130,33 @@ DOCKER_ACTIONS = Literal[
"check_updates", "check_updates",
] ]
if set(get_args(DOCKER_ACTIONS)) != ALL_ACTIONS: # Docker container IDs: 64 hex chars + optional suffix (e.g., ":local")
_missing = ALL_ACTIONS - set(get_args(DOCKER_ACTIONS))
_extra = set(get_args(DOCKER_ACTIONS)) - ALL_ACTIONS
raise RuntimeError(
f"DOCKER_ACTIONS and ALL_ACTIONS are out of sync. "
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
)
# Full PrefixedID: 64 hex chars + optional suffix (e.g., ":local")
_DOCKER_ID_PATTERN = re.compile(r"^[a-f0-9]{64}(:[a-z0-9]+)?$", re.IGNORECASE) _DOCKER_ID_PATTERN = re.compile(r"^[a-f0-9]{64}(:[a-z0-9]+)?$", re.IGNORECASE)
# Short hex prefix: at least 12 hex chars (standard Docker short ID length)
_DOCKER_SHORT_ID_PATTERN = re.compile(r"^[a-f0-9]{12,63}$", re.IGNORECASE) def _safe_get(data: dict[str, Any], *keys: str, default: Any = None) -> Any:
"""Safely traverse nested dict keys, handling None intermediates."""
current = data
for key in keys:
if not isinstance(current, dict):
return default
current = current.get(key)
return current if current is not None else default
def find_container_by_identifier( def find_container_by_identifier(
identifier: str, containers: list[dict[str, Any]], *, strict: bool = False identifier: str, containers: list[dict[str, Any]]
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
"""Find a container by ID or name with optional fuzzy matching. """Find a container by ID or name with fuzzy matching.
Match priority: Match priority:
1. Exact ID match 1. Exact ID match
2. Exact name match (case-sensitive) 2. Exact name match (case-sensitive)
When strict=False (default), also tries:
3. Name starts with identifier (case-insensitive) 3. Name starts with identifier (case-insensitive)
4. Name contains identifier as substring (case-insensitive) 4. Name contains identifier as substring (case-insensitive)
When strict=True, only exact matches (1 & 2) are used. Note: Short identifiers (e.g. "db") may match unintended containers
Use strict=True for mutations to prevent targeting the wrong container. via substring. Use more specific names or IDs for precision.
""" """
if not containers: if not containers:
return None return None
@@ -176,24 +168,20 @@ def find_container_by_identifier(
if identifier in c.get("names", []): if identifier in c.get("names", []):
return c return c
# Strict mode: no fuzzy matching allowed
if strict:
return None
id_lower = identifier.lower() id_lower = identifier.lower()
# Priority 3: prefix match (more precise than substring) # Priority 3: prefix match (more precise than substring)
for c in containers: for c in containers:
for name in c.get("names", []): for name in c.get("names", []):
if name.lower().startswith(id_lower): if name.lower().startswith(id_lower):
logger.debug(f"Prefix match: '{identifier}' -> '{name}'") logger.info(f"Prefix match: '{identifier}' -> '{name}'")
return c return c
# Priority 4: substring match (least precise) # Priority 4: substring match (least precise)
for c in containers: for c in containers:
for name in c.get("names", []): for name in c.get("names", []):
if id_lower in name.lower(): if id_lower in name.lower():
logger.debug(f"Substring match: '{identifier}' -> '{name}'") logger.info(f"Substring match: '{identifier}' -> '{name}'")
return c return c
return None return None
@@ -207,65 +195,26 @@ def get_available_container_names(containers: list[dict[str, Any]]) -> list[str]
return names return names
async def _resolve_container_id(container_id: str, *, strict: bool = False) -> str: async def _resolve_container_id(container_id: str) -> str:
"""Resolve a container name/identifier to its actual PrefixedID. """Resolve a container name/identifier to its actual PrefixedID."""
Optimization: if the identifier is a full 64-char hex ID (with optional
:suffix), skip the container list fetch entirely and use it directly.
If it's a short hex prefix (12-63 chars), fetch the list and match by
ID prefix. Only fetch the container list for name-based lookups.
Args:
container_id: Container name or ID to resolve
strict: When True, only exact name/ID matches are allowed (no fuzzy).
Use for mutations to prevent targeting the wrong container.
"""
# Full PrefixedID: skip the list fetch entirely
if _DOCKER_ID_PATTERN.match(container_id): if _DOCKER_ID_PATTERN.match(container_id):
return container_id return container_id
logger.info(f"Resolving container identifier '{container_id}' (strict={strict})") logger.info(f"Resolving container identifier '{container_id}'")
list_query = """ list_query = """
query ResolveContainerID { query ResolveContainerID {
docker { containers(skipCache: true) { id names } } docker { containers(skipCache: true) { id names } }
} }
""" """
data = await make_graphql_request(list_query) data = await make_graphql_request(list_query)
containers = safe_get(data, "docker", "containers", default=[]) containers = _safe_get(data, "docker", "containers", default=[])
resolved = find_container_by_identifier(container_id, containers)
# Short hex prefix: match by ID prefix before trying name matching
if _DOCKER_SHORT_ID_PATTERN.match(container_id):
id_lower = container_id.lower()
matches: list[dict[str, Any]] = []
for c in containers:
cid = (c.get("id") or "").lower()
if cid.startswith(id_lower) or cid.split(":")[0].startswith(id_lower):
matches.append(c)
if len(matches) == 1:
actual_id = str(matches[0].get("id", ""))
logger.info(f"Resolved short ID '{container_id}' -> '{actual_id}'")
return actual_id
if len(matches) > 1:
candidate_ids = [str(c.get("id", "")) for c in matches[:5]]
raise ToolError(
f"Short container ID prefix '{container_id}' is ambiguous. "
f"Matches: {', '.join(candidate_ids)}. Use a longer ID or exact name."
)
resolved = find_container_by_identifier(container_id, containers, strict=strict)
if resolved: if resolved:
actual_id = str(resolved.get("id", "")) actual_id = str(resolved.get("id", ""))
logger.info(f"Resolved '{container_id}' -> '{actual_id}'") logger.info(f"Resolved '{container_id}' -> '{actual_id}'")
return actual_id return actual_id
available = get_available_container_names(containers) available = get_available_container_names(containers)
if strict:
msg = (
f"Container '{container_id}' not found by exact match. "
f"Mutations require an exact container name or full ID — "
f"fuzzy/substring matching is not allowed for safety."
)
else:
msg = f"Container '{container_id}' not found." msg = f"Container '{container_id}' not found."
if available: if available:
msg += f" Available: {', '.join(available[:10])}" msg += f" Available: {', '.join(available[:10])}"
@@ -315,58 +264,56 @@ def register_docker_tool(mcp: FastMCP) -> None:
if action == "network_details" and not network_id: if action == "network_details" and not network_id:
raise ToolError("network_id is required for 'network_details' action") raise ToolError("network_id is required for 'network_details' action")
if action == "logs" and (tail_lines < 1 or tail_lines > _MAX_TAIL_LINES): try:
raise ToolError(f"tail_lines must be between 1 and {_MAX_TAIL_LINES}, got {tail_lines}")
with tool_error_handler("docker", action, logger):
logger.info(f"Executing unraid_docker action={action}") logger.info(f"Executing unraid_docker action={action}")
# --- Read-only queries --- # --- Read-only queries ---
if action == "list": if action == "list":
data = await make_graphql_request(QUERIES["list"]) data = await make_graphql_request(QUERIES["list"])
containers = safe_get(data, "docker", "containers", default=[]) containers = _safe_get(data, "docker", "containers", default=[])
return {"containers": containers} return {"containers": list(containers) if isinstance(containers, list) else []}
if action == "details": if action == "details":
# Resolve name -> ID first (skips list fetch if already an ID)
actual_id = await _resolve_container_id(container_id or "")
data = await make_graphql_request(QUERIES["details"]) data = await make_graphql_request(QUERIES["details"])
containers = safe_get(data, "docker", "containers", default=[]) containers = _safe_get(data, "docker", "containers", default=[])
# Match by resolved ID (exact match, no second list fetch needed) container = find_container_by_identifier(container_id or "", containers)
for c in containers: if container:
if c.get("id") == actual_id: return container
return c available = get_available_container_names(containers)
raise ToolError(f"Container '{container_id}' not found in details response.") msg = f"Container '{container_id}' not found."
if available:
msg += f" Available: {', '.join(available[:10])}"
raise ToolError(msg)
if action == "logs": if action == "logs":
actual_id = await _resolve_container_id(container_id or "") actual_id = await _resolve_container_id(container_id or "")
data = await make_graphql_request( data = await make_graphql_request(
QUERIES["logs"], {"id": actual_id, "tail": tail_lines} QUERIES["logs"], {"id": actual_id, "tail": tail_lines}
) )
return {"logs": safe_get(data, "docker", "logs")} return {"logs": _safe_get(data, "docker", "logs")}
if action == "networks": if action == "networks":
data = await make_graphql_request(QUERIES["networks"]) data = await make_graphql_request(QUERIES["networks"])
networks = safe_get(data, "dockerNetworks", default=[]) networks = data.get("dockerNetworks", [])
return {"networks": networks} return {"networks": list(networks) if isinstance(networks, list) else []}
if action == "network_details": if action == "network_details":
data = await make_graphql_request(QUERIES["network_details"], {"id": network_id}) data = await make_graphql_request(QUERIES["network_details"], {"id": network_id})
return dict(safe_get(data, "dockerNetwork", default={}) or {}) return dict(data.get("dockerNetwork") or {})
if action == "port_conflicts": if action == "port_conflicts":
data = await make_graphql_request(QUERIES["port_conflicts"]) data = await make_graphql_request(QUERIES["port_conflicts"])
conflicts = safe_get(data, "docker", "portConflicts", default=[]) conflicts = _safe_get(data, "docker", "portConflicts", default=[])
return {"port_conflicts": conflicts} return {"port_conflicts": list(conflicts) if isinstance(conflicts, list) else []}
if action == "check_updates": if action == "check_updates":
data = await make_graphql_request(QUERIES["check_updates"]) data = await make_graphql_request(QUERIES["check_updates"])
statuses = safe_get(data, "docker", "containerUpdateStatuses", default=[]) statuses = _safe_get(data, "docker", "containerUpdateStatuses", default=[])
return {"update_statuses": statuses} return {"update_statuses": list(statuses) if isinstance(statuses, list) else []}
# --- Mutations (strict matching: no fuzzy/substring) --- # --- Mutations ---
if action == "restart": if action == "restart":
actual_id = await _resolve_container_id(container_id or "", strict=True) actual_id = await _resolve_container_id(container_id or "")
# Stop (idempotent: treat "already stopped" as success) # Stop (idempotent: treat "already stopped" as success)
stop_data = await make_graphql_request( stop_data = await make_graphql_request(
MUTATIONS["stop"], MUTATIONS["stop"],
@@ -383,7 +330,7 @@ def register_docker_tool(mcp: FastMCP) -> None:
if start_data.get("idempotent_success"): if start_data.get("idempotent_success"):
result = {} result = {}
else: else:
result = safe_get(start_data, "docker", "start", default={}) result = _safe_get(start_data, "docker", "start", default={})
response: dict[str, Any] = { response: dict[str, Any] = {
"success": True, "success": True,
"action": "restart", "action": "restart",
@@ -395,12 +342,12 @@ def register_docker_tool(mcp: FastMCP) -> None:
if action == "update_all": if action == "update_all":
data = await make_graphql_request(MUTATIONS["update_all"]) data = await make_graphql_request(MUTATIONS["update_all"])
results = safe_get(data, "docker", "updateAllContainers", default=[]) results = _safe_get(data, "docker", "updateAllContainers", default=[])
return {"success": True, "action": "update_all", "containers": results} return {"success": True, "action": "update_all", "containers": results}
# Single-container mutations # Single-container mutations
if action in MUTATIONS: if action in MUTATIONS:
actual_id = await _resolve_container_id(container_id or "", strict=True) actual_id = await _resolve_container_id(container_id or "")
op_context: dict[str, str] | None = ( op_context: dict[str, str] | None = (
{"operation": action} if action in ("start", "stop") else None {"operation": action} if action in ("start", "stop") else None
) )
@@ -435,4 +382,10 @@ def register_docker_tool(mcp: FastMCP) -> None:
raise ToolError(f"Unhandled action '{action}' — this is a bug") raise ToolError(f"Unhandled action '{action}' — this is a bug")
except ToolError:
raise
except Exception as e:
logger.error(f"Error in unraid_docker action={action}: {e}", exc_info=True)
raise ToolError(f"Failed to execute docker/{action}: {e!s}") from e
logger.info("Docker tool registered successfully") logger.info("Docker tool registered successfully")

View File

@@ -6,7 +6,7 @@ connection testing, and subscription diagnostics.
import datetime import datetime
import time import time
from typing import Any, Literal, get_args from typing import Any, Literal
from fastmcp import FastMCP from fastmcp import FastMCP
@@ -19,22 +19,11 @@ from ..config.settings import (
VERSION, VERSION,
) )
from ..core.client import make_graphql_request from ..core.client import make_graphql_request
from ..core.exceptions import ToolError, tool_error_handler from ..core.exceptions import ToolError
from ..core.utils import safe_display_url
ALL_ACTIONS = {"check", "test_connection", "diagnose"}
HEALTH_ACTIONS = Literal["check", "test_connection", "diagnose"] HEALTH_ACTIONS = Literal["check", "test_connection", "diagnose"]
if set(get_args(HEALTH_ACTIONS)) != ALL_ACTIONS:
_missing = ALL_ACTIONS - set(get_args(HEALTH_ACTIONS))
_extra = set(get_args(HEALTH_ACTIONS)) - ALL_ACTIONS
raise RuntimeError(
"HEALTH_ACTIONS and ALL_ACTIONS are out of sync. "
f"Missing in HEALTH_ACTIONS: {_missing}; extra in HEALTH_ACTIONS: {_extra}"
)
# Severity ordering: only upgrade, never downgrade # Severity ordering: only upgrade, never downgrade
_SEVERITY = {"healthy": 0, "warning": 1, "degraded": 2, "unhealthy": 3} _SEVERITY = {"healthy": 0, "warning": 1, "degraded": 2, "unhealthy": 3}
@@ -64,10 +53,12 @@ def register_health_tool(mcp: FastMCP) -> None:
test_connection - Quick connectivity test (just checks { online }) test_connection - Quick connectivity test (just checks { online })
diagnose - Subscription system diagnostics diagnose - Subscription system diagnostics
""" """
if action not in ALL_ACTIONS: if action not in ("check", "test_connection", "diagnose"):
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(ALL_ACTIONS)}") raise ToolError(
f"Invalid action '{action}'. Must be one of: check, test_connection, diagnose"
)
with tool_error_handler("health", action, logger): try:
logger.info(f"Executing unraid_health action={action}") logger.info(f"Executing unraid_health action={action}")
if action == "test_connection": if action == "test_connection":
@@ -88,6 +79,12 @@ def register_health_tool(mcp: FastMCP) -> None:
raise ToolError(f"Unhandled action '{action}' — this is a bug") raise ToolError(f"Unhandled action '{action}' — this is a bug")
except ToolError:
raise
except Exception as e:
logger.error(f"Error in unraid_health action={action}: {e}", exc_info=True)
raise ToolError(f"Failed to execute health/{action}: {e!s}") from e
logger.info("Health tool registered successfully") logger.info("Health tool registered successfully")
@@ -114,7 +111,7 @@ async def _comprehensive_check() -> dict[str, Any]:
overview { unread { alert warning total } } overview { unread { alert warning total } }
} }
docker { docker {
containers { id state status } containers(skipCache: true) { id state status }
} }
} }
""" """
@@ -138,7 +135,7 @@ async def _comprehensive_check() -> dict[str, Any]:
if info: if info:
health_info["unraid_system"] = { health_info["unraid_system"] = {
"status": "connected", "status": "connected",
"url": safe_display_url(UNRAID_API_URL), "url": UNRAID_API_URL,
"machine_id": info.get("machineId"), "machine_id": info.get("machineId"),
"version": info.get("versions", {}).get("unraid"), "version": info.get("versions", {}).get("unraid"),
"uptime": info.get("os", {}).get("uptime"), "uptime": info.get("os", {}).get("uptime"),
@@ -209,7 +206,7 @@ async def _comprehensive_check() -> dict[str, Any]:
except Exception as e: except Exception as e:
# Intentionally broad: health checks must always return a result, # Intentionally broad: health checks must always return a result,
# even on unexpected failures, so callers never get an unhandled exception. # even on unexpected failures, so callers never get an unhandled exception.
logger.error(f"Health check failed: {e}", exc_info=True) logger.error(f"Health check failed: {e}")
return { return {
"status": "unhealthy", "status": "unhealthy",
"timestamp": datetime.datetime.now(datetime.UTC).isoformat(), "timestamp": datetime.datetime.now(datetime.UTC).isoformat(),
@@ -218,42 +215,6 @@ async def _comprehensive_check() -> dict[str, Any]:
} }
def _analyze_subscription_status(
status: dict[str, Any],
) -> tuple[int, list[dict[str, Any]]]:
"""Analyze subscription status dict, returning error count and connection issues.
This is the canonical implementation of subscription status analysis.
TODO: subscriptions/diagnostics.py has a similar status-analysis pattern
in diagnose_subscriptions(). That module could import and call this helper
directly to avoid divergence. See Code-H05.
Args:
status: Dict of subscription name -> status info from get_subscription_status().
Returns:
Tuple of (error_count, connection_issues_list).
"""
error_count = 0
connection_issues: list[dict[str, Any]] = []
for sub_name, sub_status in status.items():
runtime = sub_status.get("runtime", {})
conn_state = runtime.get("connection_state", "unknown")
if conn_state in ("error", "auth_failed", "timeout", "max_retries_exceeded"):
error_count += 1
if runtime.get("last_error"):
connection_issues.append(
{
"subscription": sub_name,
"state": conn_state,
"error": runtime["last_error"],
}
)
return error_count, connection_issues
async def _diagnose_subscriptions() -> dict[str, Any]: async def _diagnose_subscriptions() -> dict[str, Any]:
"""Import and run subscription diagnostics.""" """Import and run subscription diagnostics."""
try: try:
@@ -262,10 +223,13 @@ async def _diagnose_subscriptions() -> dict[str, Any]:
await ensure_subscriptions_started() await ensure_subscriptions_started()
status = await subscription_manager.get_subscription_status() status = subscription_manager.get_subscription_status()
error_count, connection_issues = _analyze_subscription_status(status) # This list is intentionally placed into the summary dict below and then
# appended to in the loop — the mutable alias ensures both references
# reflect the same data without a second pass.
connection_issues: list[dict[str, Any]] = []
return { diagnostic_info: dict[str, Any] = {
"timestamp": datetime.datetime.now(datetime.UTC).isoformat(), "timestamp": datetime.datetime.now(datetime.UTC).isoformat(),
"environment": { "environment": {
"auto_start_enabled": subscription_manager.auto_start_enabled, "auto_start_enabled": subscription_manager.auto_start_enabled,
@@ -277,12 +241,31 @@ async def _diagnose_subscriptions() -> dict[str, Any]:
"total_configured": len(subscription_manager.subscription_configs), "total_configured": len(subscription_manager.subscription_configs),
"active_count": len(subscription_manager.active_subscriptions), "active_count": len(subscription_manager.active_subscriptions),
"with_data": len(subscription_manager.resource_data), "with_data": len(subscription_manager.resource_data),
"in_error_state": error_count, "in_error_state": 0,
"connection_issues": connection_issues, "connection_issues": connection_issues,
}, },
} }
except ImportError as e: for sub_name, sub_status in status.items():
raise ToolError("Subscription modules not available") from e runtime = sub_status.get("runtime", {})
conn_state = runtime.get("connection_state", "unknown")
if conn_state in ("error", "auth_failed", "timeout", "max_retries_exceeded"):
diagnostic_info["summary"]["in_error_state"] += 1
if runtime.get("last_error"):
connection_issues.append(
{
"subscription": sub_name,
"state": conn_state,
"error": runtime["last_error"],
}
)
return diagnostic_info
except ImportError:
return {
"error": "Subscription modules not available",
"timestamp": datetime.datetime.now(datetime.UTC).isoformat(),
}
except Exception as e: except Exception as e:
raise ToolError(f"Failed to generate diagnostics: {e!s}") from e raise ToolError(f"Failed to generate diagnostics: {e!s}") from e

View File

@@ -4,14 +4,13 @@ Provides the `unraid_info` tool with 19 read-only actions for retrieving
system information, array status, network config, and server metadata. system information, array status, network config, and server metadata.
""" """
from typing import Any, Literal, get_args from typing import Any, Literal
from fastmcp import FastMCP from fastmcp import FastMCP
from ..config.logging import logger from ..config.logging import logger
from ..core.client import make_graphql_request from ..core.client import make_graphql_request
from ..core.exceptions import ToolError, tool_error_handler from ..core.exceptions import ToolError
from ..core.utils import format_kb
# Pre-built queries keyed by action name # Pre-built queries keyed by action name
@@ -20,7 +19,7 @@ QUERIES: dict[str, str] = {
query GetSystemInfo { query GetSystemInfo {
info { info {
os { platform distro release codename kernel arch hostname codepage logofile serial build uptime } os { platform distro release codename kernel arch hostname codepage logofile serial build uptime }
cpu { manufacturer brand vendor family model stepping revision voltage speed speedmin speedmax threads cores processors socket cache } cpu { manufacturer brand vendor family model stepping revision voltage speed speedmin speedmax threads cores processors socket cache flags }
memory { memory {
layout { bank type clockSpeed formFactor manufacturer partNum serialNum } layout { bank type clockSpeed formFactor manufacturer partNum serialNum }
} }
@@ -82,6 +81,7 @@ QUERIES: dict[str, str] = {
shareAvahiEnabled safeMode startMode configValid configError joinStatus shareAvahiEnabled safeMode startMode configValid configError joinStatus
deviceCount flashGuid flashProduct flashVendor mdState mdVersion deviceCount flashGuid flashProduct flashVendor mdState mdVersion
shareCount shareSmbCount shareNfsCount shareAfpCount shareMoverActive shareCount shareSmbCount shareNfsCount shareAfpCount shareMoverActive
csrfToken
} }
} }
""", """,
@@ -156,8 +156,6 @@ QUERIES: dict[str, str] = {
""", """,
} }
ALL_ACTIONS = set(QUERIES)
INFO_ACTIONS = Literal[ INFO_ACTIONS = Literal[
"overview", "overview",
"array", "array",
@@ -180,12 +178,8 @@ INFO_ACTIONS = Literal[
"ups_config", "ups_config",
] ]
if set(get_args(INFO_ACTIONS)) != ALL_ACTIONS: assert set(QUERIES.keys()) == set(INFO_ACTIONS.__args__), (
_missing = ALL_ACTIONS - set(get_args(INFO_ACTIONS)) "QUERIES keys and INFO_ACTIONS are out of sync"
_extra = set(get_args(INFO_ACTIONS)) - ALL_ACTIONS
raise RuntimeError(
f"QUERIES keys and INFO_ACTIONS are out of sync. "
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
) )
@@ -195,17 +189,17 @@ def _process_system_info(raw_info: dict[str, Any]) -> dict[str, Any]:
if raw_info.get("os"): if raw_info.get("os"):
os_info = raw_info["os"] os_info = raw_info["os"]
summary["os"] = ( summary["os"] = (
f"{os_info.get('distro') or 'unknown'} {os_info.get('release') or 'unknown'} " f"{os_info.get('distro', '')} {os_info.get('release', '')} "
f"({os_info.get('platform') or 'unknown'}, {os_info.get('arch') or 'unknown'})" f"({os_info.get('platform', '')}, {os_info.get('arch', '')})"
) )
summary["hostname"] = os_info.get("hostname") or "unknown" summary["hostname"] = os_info.get("hostname")
summary["uptime"] = os_info.get("uptime") summary["uptime"] = os_info.get("uptime")
if raw_info.get("cpu"): if raw_info.get("cpu"):
cpu = raw_info["cpu"] cpu = raw_info["cpu"]
summary["cpu"] = ( summary["cpu"] = (
f"{cpu.get('manufacturer') or 'unknown'} {cpu.get('brand') or 'unknown'} " f"{cpu.get('manufacturer', '')} {cpu.get('brand', '')} "
f"({cpu.get('cores') or '?'} cores, {cpu.get('threads') or '?'} threads)" f"({cpu.get('cores', '?')} cores, {cpu.get('threads', '?')} threads)"
) )
if raw_info.get("memory") and raw_info["memory"].get("layout"): if raw_info.get("memory") and raw_info["memory"].get("layout"):
@@ -213,10 +207,10 @@ def _process_system_info(raw_info: dict[str, Any]) -> dict[str, Any]:
summary["memory_layout_details"] = [] summary["memory_layout_details"] = []
for stick in mem_layout: for stick in mem_layout:
summary["memory_layout_details"].append( summary["memory_layout_details"].append(
f"Bank {stick.get('bank') or '?'}: Type {stick.get('type') or '?'}, " f"Bank {stick.get('bank', '?')}: Type {stick.get('type', '?')}, "
f"Speed {stick.get('clockSpeed') or '?'}MHz, " f"Speed {stick.get('clockSpeed', '?')}MHz, "
f"Manufacturer: {stick.get('manufacturer') or '?'}, " f"Manufacturer: {stick.get('manufacturer', '?')}, "
f"Part: {stick.get('partNum') or '?'}" f"Part: {stick.get('partNum', '?')}"
) )
summary["memory_summary"] = ( summary["memory_summary"] = (
"Stick layout details retrieved. Overall total/used/free memory stats " "Stick layout details retrieved. Overall total/used/free memory stats "
@@ -261,14 +255,31 @@ def _analyze_disk_health(disks: list[dict[str, Any]]) -> dict[str, int]:
return counts return counts
def _format_kb(k: Any) -> str:
"""Format kilobyte values into human-readable sizes."""
if k is None:
return "N/A"
try:
k = int(k)
except (ValueError, TypeError):
return "N/A"
if k >= 1024 * 1024 * 1024:
return f"{k / (1024 * 1024 * 1024):.2f} TB"
if k >= 1024 * 1024:
return f"{k / (1024 * 1024):.2f} GB"
if k >= 1024:
return f"{k / 1024:.2f} MB"
return f"{k} KB"
def _process_array_status(raw: dict[str, Any]) -> dict[str, Any]: def _process_array_status(raw: dict[str, Any]) -> dict[str, Any]:
"""Process raw array data into summary + details.""" """Process raw array data into summary + details."""
summary: dict[str, Any] = {"state": raw.get("state")} summary: dict[str, Any] = {"state": raw.get("state")}
if raw.get("capacity") and raw["capacity"].get("kilobytes"): if raw.get("capacity") and raw["capacity"].get("kilobytes"):
kb = raw["capacity"]["kilobytes"] kb = raw["capacity"]["kilobytes"]
summary["capacity_total"] = format_kb(kb.get("total")) summary["capacity_total"] = _format_kb(kb.get("total"))
summary["capacity_used"] = format_kb(kb.get("used")) summary["capacity_used"] = _format_kb(kb.get("used"))
summary["capacity_free"] = format_kb(kb.get("free")) summary["capacity_free"] = _format_kb(kb.get("free"))
summary["num_parity_disks"] = len(raw.get("parities", [])) summary["num_parity_disks"] = len(raw.get("parities", []))
summary["num_data_disks"] = len(raw.get("disks", [])) summary["num_data_disks"] = len(raw.get("disks", []))
@@ -334,8 +345,8 @@ def register_info_tool(mcp: FastMCP) -> None:
ups_device - Single UPS device (requires device_id) ups_device - Single UPS device (requires device_id)
ups_config - UPS configuration ups_config - UPS configuration
""" """
if action not in ALL_ACTIONS: if action not in QUERIES:
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(ALL_ACTIONS)}") raise ToolError(f"Invalid action '{action}'. Must be one of: {list(QUERIES.keys())}")
if action == "ups_device" and not device_id: if action == "ups_device" and not device_id:
raise ToolError("device_id is required for ups_device action") raise ToolError("device_id is required for ups_device action")
@@ -366,7 +377,7 @@ def register_info_tool(mcp: FastMCP) -> None:
"ups_devices": ("upsDevices", "ups_devices"), "ups_devices": ("upsDevices", "ups_devices"),
} }
with tool_error_handler("info", action, logger): try:
logger.info(f"Executing unraid_info action={action}") logger.info(f"Executing unraid_info action={action}")
data = await make_graphql_request(query, variables) data = await make_graphql_request(query, variables)
@@ -415,9 +426,14 @@ def register_info_tool(mcp: FastMCP) -> None:
if action in list_actions: if action in list_actions:
response_key, output_key = list_actions[action] response_key, output_key = list_actions[action]
items = data.get(response_key) or [] items = data.get(response_key) or []
normalized_items = list(items) if isinstance(items, list) else [] return {output_key: list(items) if isinstance(items, list) else []}
return {output_key: normalized_items}
raise ToolError(f"Unhandled action '{action}' — this is a bug") raise ToolError(f"Unhandled action '{action}' — this is a bug")
except ToolError:
raise
except Exception as e:
logger.error(f"Error in unraid_info action={action}: {e}", exc_info=True)
raise ToolError(f"Failed to execute info/{action}: {e!s}") from e
logger.info("Info tool registered successfully") logger.info("Info tool registered successfully")

View File

@@ -4,13 +4,13 @@ Provides the `unraid_keys` tool with 5 actions for listing, viewing,
creating, updating, and deleting API keys. creating, updating, and deleting API keys.
""" """
from typing import Any, Literal, get_args from typing import Any, Literal
from fastmcp import FastMCP from fastmcp import FastMCP
from ..config.logging import logger from ..config.logging import logger
from ..core.client import make_graphql_request from ..core.client import make_graphql_request
from ..core.exceptions import ToolError, tool_error_handler from ..core.exceptions import ToolError
QUERIES: dict[str, str] = { QUERIES: dict[str, str] = {
@@ -45,7 +45,6 @@ MUTATIONS: dict[str, str] = {
} }
DESTRUCTIVE_ACTIONS = {"delete"} DESTRUCTIVE_ACTIONS = {"delete"}
ALL_ACTIONS = set(QUERIES) | set(MUTATIONS)
KEY_ACTIONS = Literal[ KEY_ACTIONS = Literal[
"list", "list",
@@ -55,14 +54,6 @@ KEY_ACTIONS = Literal[
"delete", "delete",
] ]
if set(get_args(KEY_ACTIONS)) != ALL_ACTIONS:
_missing = ALL_ACTIONS - set(get_args(KEY_ACTIONS))
_extra = set(get_args(KEY_ACTIONS)) - ALL_ACTIONS
raise RuntimeError(
f"KEY_ACTIONS and ALL_ACTIONS are out of sync. "
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
)
def register_keys_tool(mcp: FastMCP) -> None: def register_keys_tool(mcp: FastMCP) -> None:
"""Register the unraid_keys tool with the FastMCP instance.""" """Register the unraid_keys tool with the FastMCP instance."""
@@ -85,13 +76,14 @@ def register_keys_tool(mcp: FastMCP) -> None:
update - Update an API key (requires key_id; optional name, roles) update - Update an API key (requires key_id; optional name, roles)
delete - Delete API keys (requires key_id, confirm=True) delete - Delete API keys (requires key_id, confirm=True)
""" """
if action not in ALL_ACTIONS: all_actions = set(QUERIES) | set(MUTATIONS)
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(ALL_ACTIONS)}") if action not in all_actions:
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(all_actions)}")
if action in DESTRUCTIVE_ACTIONS and not confirm: if action in DESTRUCTIVE_ACTIONS and not confirm:
raise ToolError(f"Action '{action}' is destructive. Set confirm=True to proceed.") raise ToolError(f"Action '{action}' is destructive. Set confirm=True to proceed.")
with tool_error_handler("keys", action, logger): try:
logger.info(f"Executing unraid_keys action={action}") logger.info(f"Executing unraid_keys action={action}")
if action == "list": if action == "list":
@@ -149,4 +141,10 @@ def register_keys_tool(mcp: FastMCP) -> None:
raise ToolError(f"Unhandled action '{action}' — this is a bug") raise ToolError(f"Unhandled action '{action}' — this is a bug")
except ToolError:
raise
except Exception as e:
logger.error(f"Error in unraid_keys action={action}: {e}", exc_info=True)
raise ToolError(f"Failed to execute keys/{action}: {e!s}") from e
logger.info("Keys tool registered successfully") logger.info("Keys tool registered successfully")

View File

@@ -4,13 +4,13 @@ Provides the `unraid_notifications` tool with 9 actions for viewing,
creating, archiving, and deleting system notifications. creating, archiving, and deleting system notifications.
""" """
from typing import Any, Literal, get_args from typing import Any, Literal
from fastmcp import FastMCP from fastmcp import FastMCP
from ..config.logging import logger from ..config.logging import logger
from ..core.client import make_graphql_request from ..core.client import make_graphql_request
from ..core.exceptions import ToolError, tool_error_handler from ..core.exceptions import ToolError
QUERIES: dict[str, str] = { QUERIES: dict[str, str] = {
@@ -76,8 +76,6 @@ MUTATIONS: dict[str, str] = {
} }
DESTRUCTIVE_ACTIONS = {"delete", "delete_archived"} DESTRUCTIVE_ACTIONS = {"delete", "delete_archived"}
ALL_ACTIONS = set(QUERIES) | set(MUTATIONS)
_VALID_IMPORTANCE = {"ALERT", "WARNING", "NORMAL"}
NOTIFICATION_ACTIONS = Literal[ NOTIFICATION_ACTIONS = Literal[
"overview", "overview",
@@ -91,14 +89,6 @@ NOTIFICATION_ACTIONS = Literal[
"archive_all", "archive_all",
] ]
if set(get_args(NOTIFICATION_ACTIONS)) != ALL_ACTIONS:
_missing = ALL_ACTIONS - set(get_args(NOTIFICATION_ACTIONS))
_extra = set(get_args(NOTIFICATION_ACTIONS)) - ALL_ACTIONS
raise RuntimeError(
f"NOTIFICATION_ACTIONS and ALL_ACTIONS are out of sync. "
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
)
def register_notifications_tool(mcp: FastMCP) -> None: def register_notifications_tool(mcp: FastMCP) -> None:
"""Register the unraid_notifications tool with the FastMCP instance.""" """Register the unraid_notifications tool with the FastMCP instance."""
@@ -130,13 +120,16 @@ def register_notifications_tool(mcp: FastMCP) -> None:
delete_archived - Delete all archived notifications (requires confirm=True) delete_archived - Delete all archived notifications (requires confirm=True)
archive_all - Archive all notifications (optional importance filter) archive_all - Archive all notifications (optional importance filter)
""" """
if action not in ALL_ACTIONS: all_actions = {**QUERIES, **MUTATIONS}
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(ALL_ACTIONS)}") if action not in all_actions:
raise ToolError(
f"Invalid action '{action}'. Must be one of: {list(all_actions.keys())}"
)
if action in DESTRUCTIVE_ACTIONS and not confirm: if action in DESTRUCTIVE_ACTIONS and not confirm:
raise ToolError(f"Action '{action}' is destructive. Set confirm=True to proceed.") raise ToolError(f"Action '{action}' is destructive. Set confirm=True to proceed.")
with tool_error_handler("notifications", action, logger): try:
logger.info(f"Executing unraid_notifications action={action}") logger.info(f"Executing unraid_notifications action={action}")
if action == "overview": if action == "overview":
@@ -154,29 +147,18 @@ def register_notifications_tool(mcp: FastMCP) -> None:
filter_vars["importance"] = importance.upper() filter_vars["importance"] = importance.upper()
data = await make_graphql_request(QUERIES["list"], {"filter": filter_vars}) data = await make_graphql_request(QUERIES["list"], {"filter": filter_vars})
notifications = data.get("notifications", {}) notifications = data.get("notifications", {})
return {"notifications": notifications.get("list", [])} result = notifications.get("list", [])
return {"notifications": list(result) if isinstance(result, list) else []}
if action == "warnings": if action == "warnings":
data = await make_graphql_request(QUERIES["warnings"]) data = await make_graphql_request(QUERIES["warnings"])
notifications = data.get("notifications", {}) notifications = data.get("notifications", {})
return {"warnings": notifications.get("warningsAndAlerts", [])} result = notifications.get("warningsAndAlerts", [])
return {"warnings": list(result) if isinstance(result, list) else []}
if action == "create": if action == "create":
if title is None or subject is None or description is None or importance is None: if title is None or subject is None or description is None or importance is None:
raise ToolError("create requires title, subject, description, and importance") raise ToolError("create requires title, subject, description, and importance")
if importance.upper() not in _VALID_IMPORTANCE:
raise ToolError(
f"importance must be one of: {', '.join(sorted(_VALID_IMPORTANCE))}. "
f"Got: '{importance}'"
)
if len(title) > 200:
raise ToolError(f"title must be at most 200 characters (got {len(title)})")
if len(subject) > 500:
raise ToolError(f"subject must be at most 500 characters (got {len(subject)})")
if len(description) > 2000:
raise ToolError(
f"description must be at most 2000 characters (got {len(description)})"
)
input_data = { input_data = {
"title": title, "title": title,
"subject": subject, "subject": subject,
@@ -214,4 +196,10 @@ def register_notifications_tool(mcp: FastMCP) -> None:
raise ToolError(f"Unhandled action '{action}' — this is a bug") raise ToolError(f"Unhandled action '{action}' — this is a bug")
except ToolError:
raise
except Exception as e:
logger.error(f"Error in unraid_notifications action={action}: {e}", exc_info=True)
raise ToolError(f"Failed to execute notifications/{action}: {e!s}") from e
logger.info("Notifications tool registered successfully") logger.info("Notifications tool registered successfully")

View File

@@ -4,14 +4,13 @@ Provides the `unraid_rclone` tool with 4 actions for managing
cloud storage remotes (S3, Google Drive, Dropbox, FTP, etc.). cloud storage remotes (S3, Google Drive, Dropbox, FTP, etc.).
""" """
import re from typing import Any, Literal
from typing import Any, Literal, get_args
from fastmcp import FastMCP from fastmcp import FastMCP
from ..config.logging import logger from ..config.logging import logger
from ..core.client import make_graphql_request from ..core.client import make_graphql_request
from ..core.exceptions import ToolError, tool_error_handler from ..core.exceptions import ToolError
QUERIES: dict[str, str] = { QUERIES: dict[str, str] = {
@@ -50,59 +49,6 @@ RCLONE_ACTIONS = Literal[
"delete_remote", "delete_remote",
] ]
if set(get_args(RCLONE_ACTIONS)) != ALL_ACTIONS:
_missing = ALL_ACTIONS - set(get_args(RCLONE_ACTIONS))
_extra = set(get_args(RCLONE_ACTIONS)) - ALL_ACTIONS
raise RuntimeError(
f"RCLONE_ACTIONS and ALL_ACTIONS are out of sync. "
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
)
# Max config entries to prevent abuse
_MAX_CONFIG_KEYS = 50
# Pattern for suspicious key names (path traversal, shell metacharacters)
_DANGEROUS_KEY_PATTERN = re.compile(r"\.\.|[/\\;|`$(){}]")
# Max length for individual config values
_MAX_VALUE_LENGTH = 4096
def _validate_config_data(config_data: dict[str, Any]) -> dict[str, str]:
"""Validate and sanitize rclone config_data before passing to GraphQL.
Ensures all keys and values are safe strings with no injection vectors.
Raises:
ToolError: If config_data contains invalid keys or values
"""
if len(config_data) > _MAX_CONFIG_KEYS:
raise ToolError(f"config_data has {len(config_data)} keys (max {_MAX_CONFIG_KEYS})")
validated: dict[str, str] = {}
for key, value in config_data.items():
if not isinstance(key, str) or not key.strip():
raise ToolError(
f"config_data keys must be non-empty strings, got: {type(key).__name__}"
)
if _DANGEROUS_KEY_PATTERN.search(key):
raise ToolError(
f"config_data key '{key}' contains disallowed characters "
f"(path traversal or shell metacharacters)"
)
if not isinstance(value, (str, int, float, bool)):
raise ToolError(
f"config_data['{key}'] must be a string, number, or boolean, "
f"got: {type(value).__name__}"
)
str_value = str(value)
if len(str_value) > _MAX_VALUE_LENGTH:
raise ToolError(
f"config_data['{key}'] value exceeds max length "
f"({len(str_value)} > {_MAX_VALUE_LENGTH})"
)
validated[key] = str_value
return validated
def register_rclone_tool(mcp: FastMCP) -> None: def register_rclone_tool(mcp: FastMCP) -> None:
"""Register the unraid_rclone tool with the FastMCP instance.""" """Register the unraid_rclone tool with the FastMCP instance."""
@@ -129,7 +75,7 @@ def register_rclone_tool(mcp: FastMCP) -> None:
if action in DESTRUCTIVE_ACTIONS and not confirm: if action in DESTRUCTIVE_ACTIONS and not confirm:
raise ToolError(f"Action '{action}' is destructive. Set confirm=True to proceed.") raise ToolError(f"Action '{action}' is destructive. Set confirm=True to proceed.")
with tool_error_handler("rclone", action, logger): try:
logger.info(f"Executing unraid_rclone action={action}") logger.info(f"Executing unraid_rclone action={action}")
if action == "list_remotes": if action == "list_remotes":
@@ -150,10 +96,9 @@ def register_rclone_tool(mcp: FastMCP) -> None:
if action == "create_remote": if action == "create_remote":
if name is None or provider_type is None or config_data is None: if name is None or provider_type is None or config_data is None:
raise ToolError("create_remote requires name, provider_type, and config_data") raise ToolError("create_remote requires name, provider_type, and config_data")
validated_config = _validate_config_data(config_data)
data = await make_graphql_request( data = await make_graphql_request(
MUTATIONS["create_remote"], MUTATIONS["create_remote"],
{"input": {"name": name, "type": provider_type, "config": validated_config}}, {"input": {"name": name, "type": provider_type, "config": config_data}},
) )
remote = data.get("rclone", {}).get("createRCloneRemote") remote = data.get("rclone", {}).get("createRCloneRemote")
if not remote: if not remote:
@@ -182,4 +127,10 @@ def register_rclone_tool(mcp: FastMCP) -> None:
raise ToolError(f"Unhandled action '{action}' — this is a bug") raise ToolError(f"Unhandled action '{action}' — this is a bug")
except ToolError:
raise
except Exception as e:
logger.error(f"Error in unraid_rclone action={action}: {e}", exc_info=True)
raise ToolError(f"Failed to execute rclone/{action}: {e!s}") from e
logger.info("RClone tool registered successfully") logger.info("RClone tool registered successfully")

View File

@@ -4,19 +4,17 @@ Provides the `unraid_storage` tool with 6 actions for shares, physical disks,
unassigned devices, log files, and log content retrieval. unassigned devices, log files, and log content retrieval.
""" """
import os from typing import Any, Literal
from typing import Any, Literal, get_args
import anyio
from fastmcp import FastMCP from fastmcp import FastMCP
from ..config.logging import logger from ..config.logging import logger
from ..core.client import DISK_TIMEOUT, make_graphql_request from ..core.client import DISK_TIMEOUT, make_graphql_request
from ..core.exceptions import ToolError, tool_error_handler from ..core.exceptions import ToolError
from ..core.utils import format_bytes
_ALLOWED_LOG_PREFIXES = ("/var/log/", "/boot/logs/", "/mnt/") _ALLOWED_LOG_PREFIXES = ("/var/log/", "/boot/logs/", "/mnt/")
_MAX_TAIL_LINES = 10_000
QUERIES: dict[str, str] = { QUERIES: dict[str, str] = {
"shares": """ "shares": """
@@ -58,8 +56,6 @@ QUERIES: dict[str, str] = {
""", """,
} }
ALL_ACTIONS = set(QUERIES)
STORAGE_ACTIONS = Literal[ STORAGE_ACTIONS = Literal[
"shares", "shares",
"disks", "disks",
@@ -69,13 +65,20 @@ STORAGE_ACTIONS = Literal[
"logs", "logs",
] ]
if set(get_args(STORAGE_ACTIONS)) != ALL_ACTIONS:
_missing = ALL_ACTIONS - set(get_args(STORAGE_ACTIONS)) def format_bytes(bytes_value: int | None) -> str:
_extra = set(get_args(STORAGE_ACTIONS)) - ALL_ACTIONS """Format byte values into human-readable sizes."""
raise RuntimeError( if bytes_value is None:
f"STORAGE_ACTIONS and ALL_ACTIONS are out of sync. " return "N/A"
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}" try:
) value = float(int(bytes_value))
except (ValueError, TypeError):
return "N/A"
for unit in ["B", "KB", "MB", "GB", "TB", "PB"]:
if value < 1024.0:
return f"{value:.2f} {unit}"
value /= 1024.0
return f"{value:.2f} EB"
def register_storage_tool(mcp: FastMCP) -> None: def register_storage_tool(mcp: FastMCP) -> None:
@@ -98,22 +101,17 @@ def register_storage_tool(mcp: FastMCP) -> None:
log_files - List available log files log_files - List available log files
logs - Retrieve log content (requires log_path, optional tail_lines) logs - Retrieve log content (requires log_path, optional tail_lines)
""" """
if action not in ALL_ACTIONS: if action not in QUERIES:
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(ALL_ACTIONS)}") raise ToolError(f"Invalid action '{action}'. Must be one of: {list(QUERIES.keys())}")
if action == "disk_details" and not disk_id: if action == "disk_details" and not disk_id:
raise ToolError("disk_id is required for 'disk_details' action") raise ToolError("disk_id is required for 'disk_details' action")
if action == "logs" and (tail_lines < 1 or tail_lines > _MAX_TAIL_LINES):
raise ToolError(f"tail_lines must be between 1 and {_MAX_TAIL_LINES}, got {tail_lines}")
if action == "logs": if action == "logs":
if not log_path: if not log_path:
raise ToolError("log_path is required for 'logs' action") raise ToolError("log_path is required for 'logs' action")
# Resolve path synchronously to prevent traversal attacks. # Resolve path to prevent traversal attacks (e.g. /var/log/../../etc/shadow)
# Using os.path.realpath instead of anyio.Path.resolve() because the normalized = str(await anyio.Path(log_path).resolve())
# async variant blocks on NFS-mounted paths under /mnt/ (Perf-AI-1).
normalized = os.path.realpath(log_path) # noqa: ASYNC240
if not any(normalized.startswith(p) for p in _ALLOWED_LOG_PREFIXES): if not any(normalized.startswith(p) for p in _ALLOWED_LOG_PREFIXES):
raise ToolError( raise ToolError(
f"log_path must start with one of: {', '.join(_ALLOWED_LOG_PREFIXES)}. " f"log_path must start with one of: {', '.join(_ALLOWED_LOG_PREFIXES)}. "
@@ -130,15 +128,17 @@ def register_storage_tool(mcp: FastMCP) -> None:
elif action == "logs": elif action == "logs":
variables = {"path": log_path, "lines": tail_lines} variables = {"path": log_path, "lines": tail_lines}
with tool_error_handler("storage", action, logger): try:
logger.info(f"Executing unraid_storage action={action}") logger.info(f"Executing unraid_storage action={action}")
data = await make_graphql_request(query, variables, custom_timeout=custom_timeout) data = await make_graphql_request(query, variables, custom_timeout=custom_timeout)
if action == "shares": if action == "shares":
return {"shares": data.get("shares", [])} shares = data.get("shares", [])
return {"shares": list(shares) if isinstance(shares, list) else []}
if action == "disks": if action == "disks":
return {"disks": data.get("disks", [])} disks = data.get("disks", [])
return {"disks": list(disks) if isinstance(disks, list) else []}
if action == "disk_details": if action == "disk_details":
raw = data.get("disk", {}) raw = data.get("disk", {})
@@ -159,14 +159,22 @@ def register_storage_tool(mcp: FastMCP) -> None:
return {"summary": summary, "details": raw} return {"summary": summary, "details": raw}
if action == "unassigned": if action == "unassigned":
return {"devices": data.get("unassignedDevices", [])} devices = data.get("unassignedDevices", [])
return {"devices": list(devices) if isinstance(devices, list) else []}
if action == "log_files": if action == "log_files":
return {"log_files": data.get("logFiles", [])} files = data.get("logFiles", [])
return {"log_files": list(files) if isinstance(files, list) else []}
if action == "logs": if action == "logs":
return dict(data.get("logFile") or {}) return dict(data.get("logFile") or {})
raise ToolError(f"Unhandled action '{action}' — this is a bug") raise ToolError(f"Unhandled action '{action}' — this is a bug")
except ToolError:
raise
except Exception as e:
logger.error(f"Error in unraid_storage action={action}: {e}", exc_info=True)
raise ToolError(f"Failed to execute storage/{action}: {e!s}") from e
logger.info("Storage tool registered successfully") logger.info("Storage tool registered successfully")

View File

@@ -10,7 +10,7 @@ from fastmcp import FastMCP
from ..config.logging import logger from ..config.logging import logger
from ..core.client import make_graphql_request from ..core.client import make_graphql_request
from ..core.exceptions import ToolError, tool_error_handler from ..core.exceptions import ToolError
QUERIES: dict[str, str] = { QUERIES: dict[str, str] = {
@@ -39,11 +39,17 @@ def register_users_tool(mcp: FastMCP) -> None:
Note: Unraid API does not support user management operations (list, add, delete). Note: Unraid API does not support user management operations (list, add, delete).
""" """
if action not in ALL_ACTIONS: if action not in ALL_ACTIONS:
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(ALL_ACTIONS)}") raise ToolError(f"Invalid action '{action}'. Must be: me")
with tool_error_handler("users", action, logger): try:
logger.info("Executing unraid_users action=me") logger.info("Executing unraid_users action=me")
data = await make_graphql_request(QUERIES["me"]) data = await make_graphql_request(QUERIES["me"])
return data.get("me") or {} return data.get("me") or {}
except ToolError:
raise
except Exception as e:
logger.error(f"Error in unraid_users action=me: {e}", exc_info=True)
raise ToolError(f"Failed to execute users/me: {e!s}") from e
logger.info("Users tool registered successfully") logger.info("Users tool registered successfully")

View File

@@ -4,13 +4,13 @@ Provides the `unraid_vm` tool with 9 actions for VM lifecycle management
including start, stop, pause, resume, force stop, reboot, and reset. including start, stop, pause, resume, force stop, reboot, and reset.
""" """
from typing import Any, Literal, get_args from typing import Any, Literal
from fastmcp import FastMCP from fastmcp import FastMCP
from ..config.logging import logger from ..config.logging import logger
from ..core.client import make_graphql_request from ..core.client import make_graphql_request
from ..core.exceptions import ToolError, tool_error_handler from ..core.exceptions import ToolError
QUERIES: dict[str, str] = { QUERIES: dict[str, str] = {
@@ -19,13 +19,6 @@ QUERIES: dict[str, str] = {
vms { id domains { id name state uuid } } vms { id domains { id name state uuid } }
} }
""", """,
# NOTE: The Unraid GraphQL API does not expose a single-VM query.
# The details query is identical to list; client-side filtering is required.
"details": """
query ListVMs {
vms { id domains { id name state uuid } }
}
""",
} }
MUTATIONS: dict[str, str] = { MUTATIONS: dict[str, str] = {
@@ -71,15 +64,7 @@ VM_ACTIONS = Literal[
"reset", "reset",
] ]
ALL_ACTIONS = set(QUERIES) | set(MUTATIONS) ALL_ACTIONS = set(QUERIES) | set(MUTATIONS) | {"details"}
if set(get_args(VM_ACTIONS)) != ALL_ACTIONS:
_missing = ALL_ACTIONS - set(get_args(VM_ACTIONS))
_extra = set(get_args(VM_ACTIONS)) - ALL_ACTIONS
raise RuntimeError(
f"VM_ACTIONS and ALL_ACTIONS are out of sync. "
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
)
def register_vm_tool(mcp: FastMCP) -> None: def register_vm_tool(mcp: FastMCP) -> None:
@@ -113,26 +98,20 @@ def register_vm_tool(mcp: FastMCP) -> None:
if action in DESTRUCTIVE_ACTIONS and not confirm: if action in DESTRUCTIVE_ACTIONS and not confirm:
raise ToolError(f"Action '{action}' is destructive. Set confirm=True to proceed.") raise ToolError(f"Action '{action}' is destructive. Set confirm=True to proceed.")
with tool_error_handler("vm", action, logger):
try: try:
logger.info(f"Executing unraid_vm action={action}") logger.info(f"Executing unraid_vm action={action}")
if action == "list": if action in ("list", "details"):
data = await make_graphql_request(QUERIES["list"]) data = await make_graphql_request(QUERIES["list"])
if data.get("vms"): if data.get("vms"):
vms = data["vms"].get("domains") or data["vms"].get("domain") or [] vms = data["vms"].get("domains") or data["vms"].get("domain") or []
if isinstance(vms, dict): if isinstance(vms, dict):
vms = [vms] vms = [vms]
return {"vms": vms}
return {"vms": []}
if action == "details": if action == "list":
data = await make_graphql_request(QUERIES["details"]) return {"vms": vms}
if not data.get("vms"):
raise ToolError("No VM data returned from server") # details: find specific VM
vms = data["vms"].get("domains") or data["vms"].get("domain") or []
if isinstance(vms, dict):
vms = [vms]
for vm in vms: for vm in vms:
if ( if (
vm.get("uuid") == vm_id vm.get("uuid") == vm_id
@@ -142,6 +121,9 @@ def register_vm_tool(mcp: FastMCP) -> None:
return dict(vm) return dict(vm)
available = [f"{v.get('name')} (UUID: {v.get('uuid')})" for v in vms] available = [f"{v.get('name')} (UUID: {v.get('uuid')})" for v in vms]
raise ToolError(f"VM '{vm_id}' not found. Available: {', '.join(available)}") raise ToolError(f"VM '{vm_id}' not found. Available: {', '.join(available)}")
if action == "details":
raise ToolError("No VM data returned from server")
return {"vms": []}
# Mutations # Mutations
if action in MUTATIONS: if action in MUTATIONS:
@@ -160,10 +142,12 @@ def register_vm_tool(mcp: FastMCP) -> None:
except ToolError: except ToolError:
raise raise
except Exception as e: except Exception as e:
if "VMs are not available" in str(e): logger.error(f"Error in unraid_vm action={action}: {e}", exc_info=True)
msg = str(e)
if "VMs are not available" in msg:
raise ToolError( raise ToolError(
"VMs not available on this server. Check VM support is enabled." "VMs not available on this server. Check VM support is enabled."
) from e ) from e
raise raise ToolError(f"Failed to execute vm/{action}: {msg}") from e
logger.info("VM tool registered successfully") logger.info("VM tool registered successfully")

View File

@@ -1,11 +0,0 @@
"""Application version helpers."""
from importlib.metadata import PackageNotFoundError, version
__all__ = ["VERSION"]
try:
VERSION = version("unraid-mcp")
except PackageNotFoundError:
VERSION = "0.0.0"

13
uv.lock generated
View File

@@ -1706,6 +1706,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/14/2c/dee705c427875402200fe779eb8a3c00ccb349471172c41178336e9599cc/typer-0.23.2-py3-none-any.whl", hash = "sha256:e9c8dc380f82450b3c851a9b9d5a0edf95d1d6456ae70c517d8b06a50c7a9978", size = 56834, upload-time = "2026-02-16T18:52:39.308Z" }, { url = "https://files.pythonhosted.org/packages/14/2c/dee705c427875402200fe779eb8a3c00ccb349471172c41178336e9599cc/typer-0.23.2-py3-none-any.whl", hash = "sha256:e9c8dc380f82450b3c851a9b9d5a0edf95d1d6456ae70c517d8b06a50c7a9978", size = 56834, upload-time = "2026-02-16T18:52:39.308Z" },
] ]
[[package]]
name = "types-pytz"
version = "2025.2.0.20251108"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/40/ff/c047ddc68c803b46470a357454ef76f4acd8c1088f5cc4891cdd909bfcf6/types_pytz-2025.2.0.20251108.tar.gz", hash = "sha256:fca87917836ae843f07129567b74c1929f1870610681b4c92cb86a3df5817bdb", size = 10961, upload-time = "2025-11-08T02:55:57.001Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e7/c1/56ef16bf5dcd255155cc736d276efa6ae0a5c26fd685e28f0412a4013c01/types_pytz-2025.2.0.20251108-py3-none-any.whl", hash = "sha256:0f1c9792cab4eb0e46c52f8845c8f77cf1e313cb3d68bf826aa867fe4717d91c", size = 10116, upload-time = "2025-11-08T02:55:56.194Z" },
]
[[package]] [[package]]
name = "typing-extensions" name = "typing-extensions"
version = "4.15.0" version = "4.15.0"
@@ -1736,6 +1745,7 @@ dependencies = [
{ name = "fastmcp" }, { name = "fastmcp" },
{ name = "httpx" }, { name = "httpx" },
{ name = "python-dotenv" }, { name = "python-dotenv" },
{ name = "pytz" },
{ name = "rich" }, { name = "rich" },
{ name = "uvicorn", extra = ["standard"] }, { name = "uvicorn", extra = ["standard"] },
{ name = "websockets" }, { name = "websockets" },
@@ -1752,6 +1762,7 @@ dev = [
{ name = "ruff" }, { name = "ruff" },
{ name = "twine" }, { name = "twine" },
{ name = "ty" }, { name = "ty" },
{ name = "types-pytz" },
] ]
[package.metadata] [package.metadata]
@@ -1760,6 +1771,7 @@ requires-dist = [
{ name = "fastmcp", specifier = ">=2.14.5" }, { name = "fastmcp", specifier = ">=2.14.5" },
{ name = "httpx", specifier = ">=0.28.1" }, { name = "httpx", specifier = ">=0.28.1" },
{ name = "python-dotenv", specifier = ">=1.1.1" }, { name = "python-dotenv", specifier = ">=1.1.1" },
{ name = "pytz", specifier = ">=2025.2" },
{ name = "rich", specifier = ">=14.1.0" }, { name = "rich", specifier = ">=14.1.0" },
{ name = "uvicorn", extras = ["standard"], specifier = ">=0.35.0" }, { name = "uvicorn", extras = ["standard"], specifier = ">=0.35.0" },
{ name = "websockets", specifier = ">=15.0.1" }, { name = "websockets", specifier = ">=15.0.1" },
@@ -1776,6 +1788,7 @@ dev = [
{ name = "ruff", specifier = ">=0.12.8" }, { name = "ruff", specifier = ">=0.12.8" },
{ name = "twine", specifier = ">=6.0.1" }, { name = "twine", specifier = ">=6.0.1" },
{ name = "ty", specifier = ">=0.0.15" }, { name = "ty", specifier = ">=0.0.15" },
{ name = "types-pytz", specifier = ">=2025.2.0.20250809" },
] ]
[[package]] [[package]]