forked from HomeLab/unraid-mcp
Compare commits
5 Commits
0.0.1
...
refactor/c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2a5b19c42f | ||
|
|
1751bc2984 | ||
|
|
348f4149a5 | ||
|
|
f76e676fd4 | ||
|
|
316193c04b |
@@ -1,50 +0,0 @@
|
|||||||
#!/usr/bin/env sh
|
|
||||||
echo "Running commit message checks..."
|
|
||||||
|
|
||||||
. "$(dirname -- "$0")/../../.gitea/conventional_commits/hooks/text-styles.sh"
|
|
||||||
|
|
||||||
|
|
||||||
# Get the commit message
|
|
||||||
commit="$(cat .git/COMMIT_EDITMSG)"
|
|
||||||
# Define the conventional commit regex
|
|
||||||
regex='^((build|chore|ci|docs|feat|fix|perf|refactor|revert|style|test)(\(.+\))?(!?):\s([a-zA-Z0-9-_!\&\.\%\(\)\=\w\s]+)\s?(,?\s?)((ref(s?):?\s?)(([A-Z0-9]+\-[0-9]+)|(NOISSUE))?))|(release: .*)$'
|
|
||||||
|
|
||||||
# Check if the commit message matches the conventional commit format
|
|
||||||
if ! echo "$commit" | grep -Pq "$regex"
|
|
||||||
then
|
|
||||||
echo
|
|
||||||
colorPrint red "❌ Failed to create commit. Your commit message does not follow the conventional commit format."
|
|
||||||
colorPrint red "Please use the following format: $(colorPrint brightRed 'type(scope)?: description')"
|
|
||||||
colorPrint red "Available types are listed below. Scope is optional. Use ! after type to indicate breaking change."
|
|
||||||
echo
|
|
||||||
colorPrint brightWhite "Quick examples:"
|
|
||||||
echo "feat: add email notifications on new direct messages refs ABC-1213"
|
|
||||||
echo "feat(shopping cart): add the amazing button ref: DEFG-23"
|
|
||||||
echo "feat!: remove ticket list endpoint ref DADA-109"
|
|
||||||
echo "fix(api): handle empty message in request body refs: MINE-82"
|
|
||||||
echo "chore(deps): bump some-package-name to version 2.0.0 refs ASDF-12"
|
|
||||||
echo
|
|
||||||
colorPrint brightWhite "Commit types:"
|
|
||||||
colorPrint brightCyan "build: $(colorPrint white "Changes that affect the build system or external dependencies (example scopes: gulp, broccoli, npm)" -n)"
|
|
||||||
colorPrint brightCyan "ci: $(colorPrint white "Changes to CI configuration files and scripts (example scopes: Travis, Circle, BrowserStack, SauceLabs)" -n)"
|
|
||||||
colorPrint brightCyan "chore: $(colorPrint white "Changes which doesn't change source code or tests e.g. changes to the build process, auxiliary tools, libraries" -n)"
|
|
||||||
colorPrint brightCyan "docs: $(colorPrint white "Documentation only changes" -n)"
|
|
||||||
colorPrint brightCyan "feat: $(colorPrint white "A new feature" -n)"
|
|
||||||
colorPrint brightCyan "fix: $(colorPrint white "A bug fix" -n)"
|
|
||||||
colorPrint brightCyan "perf: $(colorPrint white "A code change that improves performance" -n)"
|
|
||||||
colorPrint brightCyan "refactor: $(colorPrint white "A code change that neither fixes a bug nor adds a feature" -n)"
|
|
||||||
colorPrint brightCyan "revert: $(colorPrint white "Revert a change previously introduced" -n)"
|
|
||||||
colorPrint brightCyan "test: $(colorPrint white "Adding missing tests or correcting existing tests" -n)"
|
|
||||||
echo
|
|
||||||
|
|
||||||
colorPrint brightWhite "Reminders"
|
|
||||||
echo "Put newline before extended commit body"
|
|
||||||
echo "More details at $(underline "http://www.conventionalcommits.org")"
|
|
||||||
echo
|
|
||||||
echo "The commit message you attempted was: $commit"
|
|
||||||
echo
|
|
||||||
echo "The exact RegEx applied to this message was:"
|
|
||||||
colorPrint brightCyan "$regex"
|
|
||||||
echo
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
@@ -1,106 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
# Rules for generating semantic versioning
|
|
||||||
# major: breaking change
|
|
||||||
# minor: feat, style
|
|
||||||
# patch: build, fix, perf, refactor, revert
|
|
||||||
|
|
||||||
PREVENT_REMOVE_FILE=$1
|
|
||||||
TEMP_FILE_PATH=.gitea/conventional_commits/tmp
|
|
||||||
|
|
||||||
LAST_TAG=$(git describe --tags --abbrev=0 --always)
|
|
||||||
echo "Last tag: #$LAST_TAG#"
|
|
||||||
PATTERN="^[0-9]+\.[0-9]+\.[0-9]+$"
|
|
||||||
|
|
||||||
increment_version() {
|
|
||||||
local version=$1
|
|
||||||
local increment=$2
|
|
||||||
local major=$(echo $version | cut -d. -f1)
|
|
||||||
local minor=$(echo $version | cut -d. -f2)
|
|
||||||
local patch=$(echo $version | cut -d. -f3)
|
|
||||||
|
|
||||||
if [ "$increment" == "major" ]; then
|
|
||||||
major=$((major + 1))
|
|
||||||
minor=0
|
|
||||||
patch=0
|
|
||||||
elif [ "$increment" == "minor" ]; then
|
|
||||||
minor=$((minor + 1))
|
|
||||||
patch=0
|
|
||||||
elif [ "$increment" == "patch" ]; then
|
|
||||||
patch=$((patch + 1))
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo "${major}.${minor}.${patch}"
|
|
||||||
}
|
|
||||||
|
|
||||||
create_file() {
|
|
||||||
local with_range=$1
|
|
||||||
if [ -s $TEMP_FILE_PATH/messages.txt ]; then
|
|
||||||
return 1
|
|
||||||
fi
|
|
||||||
if [ "$with_range" == "true" ]; then
|
|
||||||
git log $LAST_TAG..HEAD --no-decorate --pretty=format:"%s" > $TEMP_FILE_PATH/messages.txt
|
|
||||||
else
|
|
||||||
git log --no-decorate --pretty=format:"%s" > $TEMP_FILE_PATH/messages.txt
|
|
||||||
fi
|
|
||||||
}
|
|
||||||
|
|
||||||
get_commit_range() {
|
|
||||||
rm -f $TEMP_FILE_PATH/messages.txt
|
|
||||||
if [[ $LAST_TAG =~ $PATTERN ]]; then
|
|
||||||
create_file true
|
|
||||||
else
|
|
||||||
create_file
|
|
||||||
LAST_TAG="0.0.0"
|
|
||||||
fi
|
|
||||||
echo " " >> $TEMP_FILE_PATH/messages.txt
|
|
||||||
}
|
|
||||||
|
|
||||||
start() {
|
|
||||||
mkdir -p $TEMP_FILE_PATH
|
|
||||||
get_commit_range
|
|
||||||
new_version=$LAST_TAG
|
|
||||||
increment_type=""
|
|
||||||
|
|
||||||
while read message; do
|
|
||||||
echo $message
|
|
||||||
if echo $message | grep -Pq '(feat|style)(\([\w]+\))?!:([a-zA-Z0-9-_!\&\.\%\(\)\=\w\s]+)\s?(,?\s?)((ref(s?):?\s?)(([A-Z0-9]+\-[0-9]+)|(#[0-9]+)|(NOISSUE))?)'; then
|
|
||||||
increment_type="major"
|
|
||||||
echo "a"
|
|
||||||
break
|
|
||||||
elif echo $message | grep -Pq '(feat|style)(\([\w]+\))?:([a-zA-Z0-9-_!\&\.\%\(\)\=\w\s]+)\s?(,?\s?)((ref(s?):?\s?)(([A-Z0-9]+\-[0-9]+)|(#[0-9]+)|(NOISSUE))?)'; then
|
|
||||||
if [ -z "$increment_type" ] || [ "$increment_type" == "patch" ]; then
|
|
||||||
increment_type="minor"
|
|
||||||
echo "b"
|
|
||||||
fi
|
|
||||||
elif echo $message | grep -Pq '(build|fix|perf|refactor|revert)(\(.+\))?:\s([a-zA-Z0-9-_!\&\.\%\(\)\=\w\s]+)\s?(,?\s?)((ref(s?):?\s?)(([A-Z0-9]+\-[0-9]+)|(#[0-9]+)|(NOISSUE))?)'; then
|
|
||||||
if [ -z "$increment_type" ]; then
|
|
||||||
increment_type="patch"
|
|
||||||
echo "c"
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
done < $TEMP_FILE_PATH/messages.txt
|
|
||||||
|
|
||||||
if [ -n "$increment_type" ]; then
|
|
||||||
new_version=$(increment_version $LAST_TAG $increment_type)
|
|
||||||
echo "New version: $new_version"
|
|
||||||
|
|
||||||
gitchangelog | grep -v "[rR]elease:" > HISTORY.md
|
|
||||||
git add HISTORY.md
|
|
||||||
echo $new_version > VERSION
|
|
||||||
git add VERSION
|
|
||||||
git commit -m "release: version $new_version 🚀"
|
|
||||||
echo "creating git tag : $new_version"
|
|
||||||
git tag $new_version
|
|
||||||
git push -u origin HEAD --tags
|
|
||||||
echo "Gitea Actions will detect the new tag and release the new version."
|
|
||||||
else
|
|
||||||
echo "No changes requiring a version increment."
|
|
||||||
fi
|
|
||||||
}
|
|
||||||
|
|
||||||
start
|
|
||||||
|
|
||||||
if [ -z "$PREVENT_REMOVE_FILE" ]; then
|
|
||||||
rm -f $TEMP_FILE_PATH/messages.txt
|
|
||||||
fi
|
|
||||||
@@ -1,44 +0,0 @@
|
|||||||
#!/bin/sh
|
|
||||||
|
|
||||||
colorPrint() {
|
|
||||||
local color=$1
|
|
||||||
local text=$2
|
|
||||||
shift 2
|
|
||||||
local newline="\n"
|
|
||||||
local tab=""
|
|
||||||
|
|
||||||
for arg in "$@"
|
|
||||||
do
|
|
||||||
if [ "$arg" = "-t" ]; then
|
|
||||||
tab="\t"
|
|
||||||
elif [ "$arg" = "-n" ]; then
|
|
||||||
newline=""
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
|
|
||||||
case $color in
|
|
||||||
black) color_code="30" ;;
|
|
||||||
red) color_code="31" ;;
|
|
||||||
green) color_code="32" ;;
|
|
||||||
yellow) color_code="33" ;;
|
|
||||||
blue) color_code="34" ;;
|
|
||||||
magenta) color_code="35" ;;
|
|
||||||
cyan) color_code="36" ;;
|
|
||||||
white) color_code="37" ;;
|
|
||||||
brightBlack) color_code="90" ;;
|
|
||||||
brightRed) color_code="91" ;;
|
|
||||||
brightGreen) color_code="92" ;;
|
|
||||||
brightYellow) color_code="93" ;;
|
|
||||||
brightBlue) color_code="94" ;;
|
|
||||||
brightMagenta) color_code="95" ;;
|
|
||||||
brightCyan) color_code="96" ;;
|
|
||||||
brightWhite) color_code="97" ;;
|
|
||||||
*) echo "Invalid color"; return ;;
|
|
||||||
esac
|
|
||||||
|
|
||||||
printf "\e[${color_code}m${tab}%s\e[0m${newline}" "$text"
|
|
||||||
}
|
|
||||||
|
|
||||||
underline () {
|
|
||||||
printf "\033[4m%s\033[24m" "$1"
|
|
||||||
}
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
#!/usr/bin/env bash
|
|
||||||
# generates changelog since last release
|
|
||||||
previous_tag=$(git tag --sort=-creatordate | sed -n 2p)
|
|
||||||
git shortlog "${previous_tag}.." | sed 's/^./ &/'
|
|
||||||
@@ -1,61 +0,0 @@
|
|||||||
name: Build Docker image
|
|
||||||
permissions:
|
|
||||||
contents: write
|
|
||||||
|
|
||||||
env:
|
|
||||||
SKIP_MAKE_SETUP_CHECK: 'true'
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
# Sequence of patterns matched against refs/tags
|
|
||||||
tags:
|
|
||||||
- '*' # Push events to matching v*, i.e. v1.0, v20.15.10
|
|
||||||
|
|
||||||
# Allows you to run this workflow manually from the Actions tab
|
|
||||||
workflow_dispatch:
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
release:
|
|
||||||
name: Create Release
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
permissions:
|
|
||||||
contents: write
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v5
|
|
||||||
with:
|
|
||||||
# by default, it uses a depth of 1
|
|
||||||
# this fetches all history so that we can read each commit
|
|
||||||
fetch-depth: 0
|
|
||||||
- name: Generate Changelog
|
|
||||||
run: .gitea/release_message.sh > release_message.md
|
|
||||||
- name: Release
|
|
||||||
uses: softprops/action-gh-release@v1
|
|
||||||
with:
|
|
||||||
body_path: release_message.md
|
|
||||||
|
|
||||||
deploy:
|
|
||||||
needs: release
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v5
|
|
||||||
- name: Check version match
|
|
||||||
run: |
|
|
||||||
REPOSITORY_NAME=$(echo "$GITHUB_REPOSITORY" | awk -F '/' '{print $2}' | tr '-' '_')
|
|
||||||
if [ "$(cat VERSION)" = "${GITHUB_REF_NAME}" ] ; then
|
|
||||||
echo "Version matches successfully!"
|
|
||||||
else
|
|
||||||
echo "Version must match!"
|
|
||||||
exit -1
|
|
||||||
fi
|
|
||||||
- name: Login to Gitea container registry
|
|
||||||
uses: docker/login-action@v3
|
|
||||||
with:
|
|
||||||
username: gitearobot
|
|
||||||
password: ${{ secrets.PACKAGE_GITEA_PAT }}
|
|
||||||
registry: git.disi.dev
|
|
||||||
- name: Build and publish
|
|
||||||
run: |
|
|
||||||
REPOSITORY_OWNER=$(echo "$GITHUB_REPOSITORY" | awk -F '/' '{print $1}' | tr '[:upper:]' '[:lower:]')
|
|
||||||
REPOSITORY_NAME=$(echo "$GITHUB_REPOSITORY" | awk -F '/' '{print $2}' | tr '-' '_')
|
|
||||||
docker build -t "git.disi.dev/$REPOSITORY_OWNER/unraid-mcp:$(cat VERSION)" ./
|
|
||||||
docker push "git.disi.dev/$REPOSITORY_OWNER/unraid-mcp:$(cat VERSION)"
|
|
||||||
28
Dockerfile
28
Dockerfile
@@ -1,5 +1,5 @@
|
|||||||
# Use an official Python runtime as a parent image
|
# Use an official Python runtime as a parent image
|
||||||
FROM python:3.11-slim
|
FROM python:3.12-slim
|
||||||
|
|
||||||
# Set the working directory in the container
|
# Set the working directory in the container
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
@@ -7,14 +7,22 @@ WORKDIR /app
|
|||||||
# Install uv
|
# Install uv
|
||||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /usr/local/bin/
|
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /usr/local/bin/
|
||||||
|
|
||||||
# Copy dependency files
|
# Create non-root user with home directory and give ownership of /app
|
||||||
COPY pyproject.toml .
|
RUN groupadd --gid 1000 appuser && \
|
||||||
COPY uv.lock .
|
useradd --uid 1000 --gid 1000 --create-home --shell /bin/false appuser && \
|
||||||
COPY README.md .
|
chown appuser:appuser /app
|
||||||
COPY LICENSE .
|
|
||||||
|
# Copy dependency files (owned by appuser via --chown)
|
||||||
|
COPY --chown=appuser:appuser pyproject.toml .
|
||||||
|
COPY --chown=appuser:appuser uv.lock .
|
||||||
|
COPY --chown=appuser:appuser README.md .
|
||||||
|
COPY --chown=appuser:appuser LICENSE .
|
||||||
|
|
||||||
# Copy the source code
|
# Copy the source code
|
||||||
COPY unraid_mcp/ ./unraid_mcp/
|
COPY --chown=appuser:appuser unraid_mcp/ ./unraid_mcp/
|
||||||
|
|
||||||
|
# Switch to non-root user before installing dependencies
|
||||||
|
USER appuser
|
||||||
|
|
||||||
# Install dependencies and the package
|
# Install dependencies and the package
|
||||||
RUN uv sync --frozen
|
RUN uv sync --frozen
|
||||||
@@ -32,5 +40,9 @@ ENV UNRAID_API_KEY=""
|
|||||||
ENV UNRAID_VERIFY_SSL="true"
|
ENV UNRAID_VERIFY_SSL="true"
|
||||||
ENV UNRAID_MCP_LOG_LEVEL="INFO"
|
ENV UNRAID_MCP_LOG_LEVEL="INFO"
|
||||||
|
|
||||||
# Run unraid-mcp-server.py when the container launches
|
# Health check
|
||||||
|
HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \
|
||||||
|
CMD ["python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:6970/mcp')"]
|
||||||
|
|
||||||
|
# Run unraid-mcp-server when the container launches
|
||||||
CMD ["uv", "run", "unraid-mcp-server"]
|
CMD ["uv", "run", "unraid-mcp-server"]
|
||||||
|
|||||||
26
Makefile
26
Makefile
@@ -1,26 +0,0 @@
|
|||||||
.ONESHELL:
|
|
||||||
|
|
||||||
VERSION ?= $(shell cat ./VERSION)
|
|
||||||
|
|
||||||
.PHONY: issetup
|
|
||||||
issetup:
|
|
||||||
@[ -f .git/hooks/commit-msg ] || [ $SKIP_MAKE_SETUP_CHECK = "true" ] || (echo "You must run 'make setup' first to initialize the repo!" && exit 1)
|
|
||||||
|
|
||||||
.PHONY: setup
|
|
||||||
setup:
|
|
||||||
@cp .gitea/conventional_commits/commit-msg .git/hooks/
|
|
||||||
|
|
||||||
.PHONY: help
|
|
||||||
help: ## Show the help.
|
|
||||||
@echo "Usage: make <target>"
|
|
||||||
@echo ""
|
|
||||||
@echo "Targets:"
|
|
||||||
@fgrep "##" Makefile | fgrep -v fgrep
|
|
||||||
|
|
||||||
.PHONY: release
|
|
||||||
release: issetup ## Create a new tag for release.
|
|
||||||
@./.gitea/conventional_commits/generate-version.sh
|
|
||||||
|
|
||||||
.PHONY: build
|
|
||||||
build: issetup
|
|
||||||
@docker build -t unraid-mcp:${VERSION} .
|
|
||||||
@@ -5,6 +5,11 @@ services:
|
|||||||
dockerfile: Dockerfile
|
dockerfile: Dockerfile
|
||||||
container_name: unraid-mcp
|
container_name: unraid-mcp
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
read_only: true
|
||||||
|
cap_drop:
|
||||||
|
- ALL
|
||||||
|
tmpfs:
|
||||||
|
- /tmp:noexec,nosuid,size=64m
|
||||||
ports:
|
ports:
|
||||||
# HostPort:ContainerPort (maps to UNRAID_MCP_PORT inside the container, default 6970)
|
# HostPort:ContainerPort (maps to UNRAID_MCP_PORT inside the container, default 6970)
|
||||||
# Change the host port (left side) if 6970 is already in use on your host
|
# Change the host port (left side) if 6970 is already in use on your host
|
||||||
@@ -13,23 +18,23 @@ services:
|
|||||||
# Core API Configuration (Required)
|
# Core API Configuration (Required)
|
||||||
- UNRAID_API_URL=${UNRAID_API_URL}
|
- UNRAID_API_URL=${UNRAID_API_URL}
|
||||||
- UNRAID_API_KEY=${UNRAID_API_KEY}
|
- UNRAID_API_KEY=${UNRAID_API_KEY}
|
||||||
|
|
||||||
# MCP Server Settings
|
# MCP Server Settings
|
||||||
- UNRAID_MCP_PORT=${UNRAID_MCP_PORT:-6970}
|
- UNRAID_MCP_PORT=${UNRAID_MCP_PORT:-6970}
|
||||||
- UNRAID_MCP_HOST=${UNRAID_MCP_HOST:-0.0.0.0}
|
- UNRAID_MCP_HOST=${UNRAID_MCP_HOST:-0.0.0.0}
|
||||||
- UNRAID_MCP_TRANSPORT=${UNRAID_MCP_TRANSPORT:-streamable-http}
|
- UNRAID_MCP_TRANSPORT=${UNRAID_MCP_TRANSPORT:-streamable-http}
|
||||||
|
|
||||||
# SSL Configuration
|
# SSL Configuration
|
||||||
- UNRAID_VERIFY_SSL=${UNRAID_VERIFY_SSL:-true}
|
- UNRAID_VERIFY_SSL=${UNRAID_VERIFY_SSL:-true}
|
||||||
|
|
||||||
# Logging Configuration
|
# Logging Configuration
|
||||||
- UNRAID_MCP_LOG_LEVEL=${UNRAID_MCP_LOG_LEVEL:-INFO}
|
- UNRAID_MCP_LOG_LEVEL=${UNRAID_MCP_LOG_LEVEL:-INFO}
|
||||||
- UNRAID_MCP_LOG_FILE=${UNRAID_MCP_LOG_FILE:-unraid-mcp.log}
|
- UNRAID_MCP_LOG_FILE=${UNRAID_MCP_LOG_FILE:-unraid-mcp.log}
|
||||||
|
|
||||||
# Real-time Subscription Configuration
|
# Real-time Subscription Configuration
|
||||||
- UNRAID_AUTO_START_SUBSCRIPTIONS=${UNRAID_AUTO_START_SUBSCRIPTIONS:-true}
|
- UNRAID_AUTO_START_SUBSCRIPTIONS=${UNRAID_AUTO_START_SUBSCRIPTIONS:-true}
|
||||||
- UNRAID_MAX_RECONNECT_ATTEMPTS=${UNRAID_MAX_RECONNECT_ATTEMPTS:-10}
|
- UNRAID_MAX_RECONNECT_ATTEMPTS=${UNRAID_MAX_RECONNECT_ATTEMPTS:-10}
|
||||||
|
|
||||||
# Optional: Custom log file path for subscription auto-start diagnostics
|
# Optional: Custom log file path for subscription auto-start diagnostics
|
||||||
- UNRAID_AUTOSTART_LOG_PATH=${UNRAID_AUTOSTART_LOG_PATH}
|
- UNRAID_AUTOSTART_LOG_PATH=${UNRAID_AUTOSTART_LOG_PATH}
|
||||||
# Optional: If you want to mount a specific directory for logs (ensure UNRAID_MCP_LOG_FILE points within this mount)
|
# Optional: If you want to mount a specific directory for logs (ensure UNRAID_MCP_LOG_FILE points within this mount)
|
||||||
|
|||||||
@@ -77,7 +77,6 @@ dependencies = [
|
|||||||
"uvicorn[standard]>=0.35.0",
|
"uvicorn[standard]>=0.35.0",
|
||||||
"websockets>=15.0.1",
|
"websockets>=15.0.1",
|
||||||
"rich>=14.1.0",
|
"rich>=14.1.0",
|
||||||
"pytz>=2025.2",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
@@ -170,6 +169,8 @@ select = [
|
|||||||
"PERF",
|
"PERF",
|
||||||
# Ruff-specific rules
|
# Ruff-specific rules
|
||||||
"RUF",
|
"RUF",
|
||||||
|
# flake8-bandit (security)
|
||||||
|
"S",
|
||||||
]
|
]
|
||||||
ignore = [
|
ignore = [
|
||||||
"E501", # line too long (handled by ruff formatter)
|
"E501", # line too long (handled by ruff formatter)
|
||||||
@@ -285,7 +286,6 @@ dev = [
|
|||||||
"pytest-asyncio>=1.2.0",
|
"pytest-asyncio>=1.2.0",
|
||||||
"pytest-cov>=7.0.0",
|
"pytest-cov>=7.0.0",
|
||||||
"respx>=0.22.0",
|
"respx>=0.22.0",
|
||||||
"types-pytz>=2025.2.0.20250809",
|
|
||||||
"ty>=0.0.15",
|
"ty>=0.0.15",
|
||||||
"ruff>=0.12.8",
|
"ruff>=0.12.8",
|
||||||
"build>=1.2.2",
|
"build>=1.2.2",
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from tests.conftest import make_tool_fn
|
|||||||
from unraid_mcp.core.client import DEFAULT_TIMEOUT, DISK_TIMEOUT, make_graphql_request
|
from unraid_mcp.core.client import DEFAULT_TIMEOUT, DISK_TIMEOUT, make_graphql_request
|
||||||
from unraid_mcp.core.exceptions import ToolError
|
from unraid_mcp.core.exceptions import ToolError
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Shared fixtures
|
# Shared fixtures
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -158,43 +159,43 @@ class TestHttpErrorHandling:
|
|||||||
@respx.mock
|
@respx.mock
|
||||||
async def test_http_401_raises_tool_error(self) -> None:
|
async def test_http_401_raises_tool_error(self) -> None:
|
||||||
respx.post(API_URL).mock(return_value=httpx.Response(401, text="Unauthorized"))
|
respx.post(API_URL).mock(return_value=httpx.Response(401, text="Unauthorized"))
|
||||||
with pytest.raises(ToolError, match="HTTP error 401"):
|
with pytest.raises(ToolError, match="Unraid API returned HTTP 401"):
|
||||||
await make_graphql_request("query { online }")
|
await make_graphql_request("query { online }")
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
async def test_http_403_raises_tool_error(self) -> None:
|
async def test_http_403_raises_tool_error(self) -> None:
|
||||||
respx.post(API_URL).mock(return_value=httpx.Response(403, text="Forbidden"))
|
respx.post(API_URL).mock(return_value=httpx.Response(403, text="Forbidden"))
|
||||||
with pytest.raises(ToolError, match="HTTP error 403"):
|
with pytest.raises(ToolError, match="Unraid API returned HTTP 403"):
|
||||||
await make_graphql_request("query { online }")
|
await make_graphql_request("query { online }")
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
async def test_http_500_raises_tool_error(self) -> None:
|
async def test_http_500_raises_tool_error(self) -> None:
|
||||||
respx.post(API_URL).mock(return_value=httpx.Response(500, text="Internal Server Error"))
|
respx.post(API_URL).mock(return_value=httpx.Response(500, text="Internal Server Error"))
|
||||||
with pytest.raises(ToolError, match="HTTP error 500"):
|
with pytest.raises(ToolError, match="Unraid API returned HTTP 500"):
|
||||||
await make_graphql_request("query { online }")
|
await make_graphql_request("query { online }")
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
async def test_http_503_raises_tool_error(self) -> None:
|
async def test_http_503_raises_tool_error(self) -> None:
|
||||||
respx.post(API_URL).mock(return_value=httpx.Response(503, text="Service Unavailable"))
|
respx.post(API_URL).mock(return_value=httpx.Response(503, text="Service Unavailable"))
|
||||||
with pytest.raises(ToolError, match="HTTP error 503"):
|
with pytest.raises(ToolError, match="Unraid API returned HTTP 503"):
|
||||||
await make_graphql_request("query { online }")
|
await make_graphql_request("query { online }")
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
async def test_network_connection_error(self) -> None:
|
async def test_network_connection_error(self) -> None:
|
||||||
respx.post(API_URL).mock(side_effect=httpx.ConnectError("Connection refused"))
|
respx.post(API_URL).mock(side_effect=httpx.ConnectError("Connection refused"))
|
||||||
with pytest.raises(ToolError, match="Network connection error"):
|
with pytest.raises(ToolError, match="Network error connecting to Unraid API"):
|
||||||
await make_graphql_request("query { online }")
|
await make_graphql_request("query { online }")
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
async def test_network_timeout_error(self) -> None:
|
async def test_network_timeout_error(self) -> None:
|
||||||
respx.post(API_URL).mock(side_effect=httpx.ReadTimeout("Read timed out"))
|
respx.post(API_URL).mock(side_effect=httpx.ReadTimeout("Read timed out"))
|
||||||
with pytest.raises(ToolError, match="Network connection error"):
|
with pytest.raises(ToolError, match="Network error connecting to Unraid API"):
|
||||||
await make_graphql_request("query { online }")
|
await make_graphql_request("query { online }")
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
async def test_invalid_json_response(self) -> None:
|
async def test_invalid_json_response(self) -> None:
|
||||||
respx.post(API_URL).mock(return_value=httpx.Response(200, text="not json"))
|
respx.post(API_URL).mock(return_value=httpx.Response(200, text="not json"))
|
||||||
with pytest.raises(ToolError, match="Invalid JSON response"):
|
with pytest.raises(ToolError, match=r"invalid response.*not valid JSON"):
|
||||||
await make_graphql_request("query { online }")
|
await make_graphql_request("query { online }")
|
||||||
|
|
||||||
|
|
||||||
@@ -582,7 +583,7 @@ class TestVMToolRequests:
|
|||||||
return_value=_graphql_response({"vm": {"stop": True}})
|
return_value=_graphql_response({"vm": {"stop": True}})
|
||||||
)
|
)
|
||||||
tool = self._get_tool()
|
tool = self._get_tool()
|
||||||
result = await tool(action="stop", vm_id="vm-456")
|
await tool(action="stop", vm_id="vm-456")
|
||||||
body = _extract_request_body(route.calls.last.request)
|
body = _extract_request_body(route.calls.last.request)
|
||||||
assert "StopVM" in body["query"]
|
assert "StopVM" in body["query"]
|
||||||
assert body["variables"] == {"id": "vm-456"}
|
assert body["variables"] == {"id": "vm-456"}
|
||||||
@@ -868,14 +869,14 @@ class TestNotificationsToolRequests:
|
|||||||
title="Test",
|
title="Test",
|
||||||
subject="Sub",
|
subject="Sub",
|
||||||
description="Desc",
|
description="Desc",
|
||||||
importance="info",
|
importance="normal",
|
||||||
)
|
)
|
||||||
body = _extract_request_body(route.calls.last.request)
|
body = _extract_request_body(route.calls.last.request)
|
||||||
assert "CreateNotification" in body["query"]
|
assert "CreateNotification" in body["query"]
|
||||||
inp = body["variables"]["input"]
|
inp = body["variables"]["input"]
|
||||||
assert inp["title"] == "Test"
|
assert inp["title"] == "Test"
|
||||||
assert inp["subject"] == "Sub"
|
assert inp["subject"] == "Sub"
|
||||||
assert inp["importance"] == "INFO" # uppercased
|
assert inp["importance"] == "NORMAL" # uppercased from "normal"
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
async def test_archive_sends_id_variable(self) -> None:
|
async def test_archive_sends_id_variable(self) -> None:
|
||||||
@@ -1256,7 +1257,7 @@ class TestCrossCuttingConcerns:
|
|||||||
tool = make_tool_fn(
|
tool = make_tool_fn(
|
||||||
"unraid_mcp.tools.info", "register_info_tool", "unraid_info"
|
"unraid_mcp.tools.info", "register_info_tool", "unraid_info"
|
||||||
)
|
)
|
||||||
with pytest.raises(ToolError, match="HTTP error 500"):
|
with pytest.raises(ToolError, match="Unraid API returned HTTP 500"):
|
||||||
await tool(action="online")
|
await tool(action="online")
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
@@ -1268,7 +1269,7 @@ class TestCrossCuttingConcerns:
|
|||||||
tool = make_tool_fn(
|
tool = make_tool_fn(
|
||||||
"unraid_mcp.tools.info", "register_info_tool", "unraid_info"
|
"unraid_mcp.tools.info", "register_info_tool", "unraid_info"
|
||||||
)
|
)
|
||||||
with pytest.raises(ToolError, match="Network connection error"):
|
with pytest.raises(ToolError, match="Network error connecting to Unraid API"):
|
||||||
await tool(action="online")
|
await tool(action="online")
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ data management without requiring a live Unraid server.
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
from datetime import datetime
|
from datetime import UTC, datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
@@ -16,6 +16,7 @@ import websockets.exceptions
|
|||||||
|
|
||||||
from unraid_mcp.subscriptions.manager import SubscriptionManager
|
from unraid_mcp.subscriptions.manager import SubscriptionManager
|
||||||
|
|
||||||
|
|
||||||
pytestmark = pytest.mark.integration
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
|
||||||
@@ -83,7 +84,7 @@ SAMPLE_QUERY = "subscription { test { value } }"
|
|||||||
|
|
||||||
# Shared patch targets
|
# Shared patch targets
|
||||||
_WS_CONNECT = "unraid_mcp.subscriptions.manager.websockets.connect"
|
_WS_CONNECT = "unraid_mcp.subscriptions.manager.websockets.connect"
|
||||||
_API_URL = "unraid_mcp.subscriptions.manager.UNRAID_API_URL"
|
_API_URL = "unraid_mcp.subscriptions.utils.UNRAID_API_URL"
|
||||||
_API_KEY = "unraid_mcp.subscriptions.manager.UNRAID_API_KEY"
|
_API_KEY = "unraid_mcp.subscriptions.manager.UNRAID_API_KEY"
|
||||||
_SSL_CTX = "unraid_mcp.subscriptions.manager.build_ws_ssl_context"
|
_SSL_CTX = "unraid_mcp.subscriptions.manager.build_ws_ssl_context"
|
||||||
_SLEEP = "unraid_mcp.subscriptions.manager.asyncio.sleep"
|
_SLEEP = "unraid_mcp.subscriptions.manager.asyncio.sleep"
|
||||||
@@ -100,7 +101,7 @@ class TestSubscriptionManagerInit:
|
|||||||
mgr = SubscriptionManager()
|
mgr = SubscriptionManager()
|
||||||
assert mgr.active_subscriptions == {}
|
assert mgr.active_subscriptions == {}
|
||||||
assert mgr.resource_data == {}
|
assert mgr.resource_data == {}
|
||||||
assert mgr.websocket is None
|
assert not hasattr(mgr, "websocket")
|
||||||
|
|
||||||
def test_default_auto_start_enabled(self) -> None:
|
def test_default_auto_start_enabled(self) -> None:
|
||||||
mgr = SubscriptionManager()
|
mgr = SubscriptionManager()
|
||||||
@@ -720,20 +721,20 @@ class TestWebSocketURLConstruction:
|
|||||||
|
|
||||||
class TestResourceData:
|
class TestResourceData:
|
||||||
|
|
||||||
def test_get_resource_data_returns_none_when_empty(self) -> None:
|
async def test_get_resource_data_returns_none_when_empty(self) -> None:
|
||||||
mgr = SubscriptionManager()
|
mgr = SubscriptionManager()
|
||||||
assert mgr.get_resource_data("nonexistent") is None
|
assert await mgr.get_resource_data("nonexistent") is None
|
||||||
|
|
||||||
def test_get_resource_data_returns_stored_data(self) -> None:
|
async def test_get_resource_data_returns_stored_data(self) -> None:
|
||||||
from unraid_mcp.core.types import SubscriptionData
|
from unraid_mcp.core.types import SubscriptionData
|
||||||
|
|
||||||
mgr = SubscriptionManager()
|
mgr = SubscriptionManager()
|
||||||
mgr.resource_data["test"] = SubscriptionData(
|
mgr.resource_data["test"] = SubscriptionData(
|
||||||
data={"key": "value"},
|
data={"key": "value"},
|
||||||
last_updated=datetime.now(),
|
last_updated=datetime.now(UTC),
|
||||||
subscription_type="test",
|
subscription_type="test",
|
||||||
)
|
)
|
||||||
result = mgr.get_resource_data("test")
|
result = await mgr.get_resource_data("test")
|
||||||
assert result == {"key": "value"}
|
assert result == {"key": "value"}
|
||||||
|
|
||||||
def test_list_active_subscriptions_empty(self) -> None:
|
def test_list_active_subscriptions_empty(self) -> None:
|
||||||
@@ -755,46 +756,46 @@ class TestResourceData:
|
|||||||
|
|
||||||
class TestSubscriptionStatus:
|
class TestSubscriptionStatus:
|
||||||
|
|
||||||
def test_status_includes_all_configured_subscriptions(self) -> None:
|
async def test_status_includes_all_configured_subscriptions(self) -> None:
|
||||||
mgr = SubscriptionManager()
|
mgr = SubscriptionManager()
|
||||||
status = mgr.get_subscription_status()
|
status = await mgr.get_subscription_status()
|
||||||
for name in mgr.subscription_configs:
|
for name in mgr.subscription_configs:
|
||||||
assert name in status
|
assert name in status
|
||||||
|
|
||||||
def test_status_default_connection_state(self) -> None:
|
async def test_status_default_connection_state(self) -> None:
|
||||||
mgr = SubscriptionManager()
|
mgr = SubscriptionManager()
|
||||||
status = mgr.get_subscription_status()
|
status = await mgr.get_subscription_status()
|
||||||
for sub_status in status.values():
|
for sub_status in status.values():
|
||||||
assert sub_status["runtime"]["connection_state"] == "not_started"
|
assert sub_status["runtime"]["connection_state"] == "not_started"
|
||||||
|
|
||||||
def test_status_shows_active_flag(self) -> None:
|
async def test_status_shows_active_flag(self) -> None:
|
||||||
mgr = SubscriptionManager()
|
mgr = SubscriptionManager()
|
||||||
mgr.active_subscriptions["logFileSubscription"] = MagicMock()
|
mgr.active_subscriptions["logFileSubscription"] = MagicMock()
|
||||||
status = mgr.get_subscription_status()
|
status = await mgr.get_subscription_status()
|
||||||
assert status["logFileSubscription"]["runtime"]["active"] is True
|
assert status["logFileSubscription"]["runtime"]["active"] is True
|
||||||
|
|
||||||
def test_status_shows_data_availability(self) -> None:
|
async def test_status_shows_data_availability(self) -> None:
|
||||||
from unraid_mcp.core.types import SubscriptionData
|
from unraid_mcp.core.types import SubscriptionData
|
||||||
|
|
||||||
mgr = SubscriptionManager()
|
mgr = SubscriptionManager()
|
||||||
mgr.resource_data["logFileSubscription"] = SubscriptionData(
|
mgr.resource_data["logFileSubscription"] = SubscriptionData(
|
||||||
data={"log": "content"},
|
data={"log": "content"},
|
||||||
last_updated=datetime.now(),
|
last_updated=datetime.now(UTC),
|
||||||
subscription_type="logFileSubscription",
|
subscription_type="logFileSubscription",
|
||||||
)
|
)
|
||||||
status = mgr.get_subscription_status()
|
status = await mgr.get_subscription_status()
|
||||||
assert status["logFileSubscription"]["data"]["available"] is True
|
assert status["logFileSubscription"]["data"]["available"] is True
|
||||||
|
|
||||||
def test_status_shows_error_info(self) -> None:
|
async def test_status_shows_error_info(self) -> None:
|
||||||
mgr = SubscriptionManager()
|
mgr = SubscriptionManager()
|
||||||
mgr.last_error["logFileSubscription"] = "Test error message"
|
mgr.last_error["logFileSubscription"] = "Test error message"
|
||||||
status = mgr.get_subscription_status()
|
status = await mgr.get_subscription_status()
|
||||||
assert status["logFileSubscription"]["runtime"]["last_error"] == "Test error message"
|
assert status["logFileSubscription"]["runtime"]["last_error"] == "Test error message"
|
||||||
|
|
||||||
def test_status_reconnect_attempts_tracked(self) -> None:
|
async def test_status_reconnect_attempts_tracked(self) -> None:
|
||||||
mgr = SubscriptionManager()
|
mgr = SubscriptionManager()
|
||||||
mgr.reconnect_attempts["logFileSubscription"] = 3
|
mgr.reconnect_attempts["logFileSubscription"] = 3
|
||||||
status = mgr.get_subscription_status()
|
status = await mgr.get_subscription_status()
|
||||||
assert status["logFileSubscription"]["runtime"]["reconnect_attempts"] == 3
|
assert status["logFileSubscription"]["runtime"]["reconnect_attempts"] == 3
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,12 @@ from unittest.mock import AsyncMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
# conftest.py is the shared test-helper module for this project.
|
||||||
|
# pytest automatically adds tests/ to sys.path, making it importable here
|
||||||
|
# without a package __init__.py. Do NOT add tests/__init__.py — it breaks
|
||||||
|
# conftest.py's fixture auto-discovery.
|
||||||
|
from conftest import make_tool_fn
|
||||||
|
|
||||||
from unraid_mcp.core.exceptions import ToolError
|
from unraid_mcp.core.exceptions import ToolError
|
||||||
|
|
||||||
# Import DESTRUCTIVE_ACTIONS sets from every tool module that defines one
|
# Import DESTRUCTIVE_ACTIONS sets from every tool module that defines one
|
||||||
@@ -24,10 +30,6 @@ from unraid_mcp.tools.rclone import MUTATIONS as RCLONE_MUTATIONS
|
|||||||
from unraid_mcp.tools.virtualization import DESTRUCTIVE_ACTIONS as VM_DESTRUCTIVE
|
from unraid_mcp.tools.virtualization import DESTRUCTIVE_ACTIONS as VM_DESTRUCTIVE
|
||||||
from unraid_mcp.tools.virtualization import MUTATIONS as VM_MUTATIONS
|
from unraid_mcp.tools.virtualization import MUTATIONS as VM_MUTATIONS
|
||||||
|
|
||||||
# Centralized import for make_tool_fn helper
|
|
||||||
# conftest.py sits in tests/ and is importable without __init__.py
|
|
||||||
from conftest import make_tool_fn
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Known destructive actions registry (ground truth for this audit)
|
# Known destructive actions registry (ground truth for this audit)
|
||||||
@@ -39,7 +41,7 @@ KNOWN_DESTRUCTIVE: dict[str, dict[str, set[str]]] = {
|
|||||||
"module": "unraid_mcp.tools.docker",
|
"module": "unraid_mcp.tools.docker",
|
||||||
"register_fn": "register_docker_tool",
|
"register_fn": "register_docker_tool",
|
||||||
"tool_name": "unraid_docker",
|
"tool_name": "unraid_docker",
|
||||||
"actions": {"remove"},
|
"actions": {"remove", "update_all"},
|
||||||
"runtime_set": DOCKER_DESTRUCTIVE,
|
"runtime_set": DOCKER_DESTRUCTIVE,
|
||||||
},
|
},
|
||||||
"vm": {
|
"vm": {
|
||||||
@@ -126,9 +128,11 @@ class TestDestructiveActionRegistries:
|
|||||||
missing: list[str] = []
|
missing: list[str] = []
|
||||||
for tool_key, mutations in all_mutations.items():
|
for tool_key, mutations in all_mutations.items():
|
||||||
destructive = all_destructive[tool_key]
|
destructive = all_destructive[tool_key]
|
||||||
for action_name in mutations:
|
missing.extend(
|
||||||
if ("delete" in action_name or "remove" in action_name) and action_name not in destructive:
|
f"{tool_key}/{action_name}"
|
||||||
missing.append(f"{tool_key}/{action_name}")
|
for action_name in mutations
|
||||||
|
if ("delete" in action_name or "remove" in action_name) and action_name not in destructive
|
||||||
|
)
|
||||||
assert not missing, (
|
assert not missing, (
|
||||||
f"Mutations with 'delete'/'remove' not in DESTRUCTIVE_ACTIONS: {missing}"
|
f"Mutations with 'delete'/'remove' not in DESTRUCTIVE_ACTIONS: {missing}"
|
||||||
)
|
)
|
||||||
@@ -143,6 +147,7 @@ class TestDestructiveActionRegistries:
|
|||||||
_DESTRUCTIVE_TEST_CASES: list[tuple[str, str, dict]] = [
|
_DESTRUCTIVE_TEST_CASES: list[tuple[str, str, dict]] = [
|
||||||
# Docker
|
# Docker
|
||||||
("docker", "remove", {"container_id": "abc123"}),
|
("docker", "remove", {"container_id": "abc123"}),
|
||||||
|
("docker", "update_all", {}),
|
||||||
# VM
|
# VM
|
||||||
("vm", "force_stop", {"vm_id": "test-vm-uuid"}),
|
("vm", "force_stop", {"vm_id": "test-vm-uuid"}),
|
||||||
("vm", "reset", {"vm_id": "test-vm-uuid"}),
|
("vm", "reset", {"vm_id": "test-vm-uuid"}),
|
||||||
@@ -268,6 +273,15 @@ class TestConfirmationGuards:
|
|||||||
class TestConfirmAllowsExecution:
|
class TestConfirmAllowsExecution:
|
||||||
"""Destructive actions with confirm=True should reach the GraphQL layer."""
|
"""Destructive actions with confirm=True should reach the GraphQL layer."""
|
||||||
|
|
||||||
|
async def test_docker_update_all_with_confirm(self, _mock_docker_graphql: AsyncMock) -> None:
|
||||||
|
_mock_docker_graphql.return_value = {
|
||||||
|
"docker": {"updateAllContainers": [{"id": "c1", "names": ["app"], "state": "running", "status": "Up"}]}
|
||||||
|
}
|
||||||
|
tool_fn = make_tool_fn("unraid_mcp.tools.docker", "register_docker_tool", "unraid_docker")
|
||||||
|
result = await tool_fn(action="update_all", confirm=True)
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["action"] == "update_all"
|
||||||
|
|
||||||
async def test_docker_remove_with_confirm(self, _mock_docker_graphql: AsyncMock) -> None:
|
async def test_docker_remove_with_confirm(self, _mock_docker_graphql: AsyncMock) -> None:
|
||||||
cid = "a" * 64 + ":local"
|
cid = "a" * 64 + ":local"
|
||||||
_mock_docker_graphql.side_effect = [
|
_mock_docker_graphql.side_effect = [
|
||||||
|
|||||||
@@ -384,10 +384,16 @@ class TestVmQueries:
|
|||||||
errors = _validate_operation(schema, QUERIES["list"])
|
errors = _validate_operation(schema, QUERIES["list"])
|
||||||
assert not errors, f"list query validation failed: {errors}"
|
assert not errors, f"list query validation failed: {errors}"
|
||||||
|
|
||||||
|
def test_details_query(self, schema: GraphQLSchema) -> None:
|
||||||
|
from unraid_mcp.tools.virtualization import QUERIES
|
||||||
|
|
||||||
|
errors = _validate_operation(schema, QUERIES["details"])
|
||||||
|
assert not errors, f"details query validation failed: {errors}"
|
||||||
|
|
||||||
def test_all_vm_queries_covered(self, schema: GraphQLSchema) -> None:
|
def test_all_vm_queries_covered(self, schema: GraphQLSchema) -> None:
|
||||||
from unraid_mcp.tools.virtualization import QUERIES
|
from unraid_mcp.tools.virtualization import QUERIES
|
||||||
|
|
||||||
assert set(QUERIES.keys()) == {"list"}
|
assert set(QUERIES.keys()) == {"list", "details"}
|
||||||
|
|
||||||
|
|
||||||
class TestVmMutations:
|
class TestVmMutations:
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ class TestArrayActions:
|
|||||||
async def test_generic_exception_wraps(self, _mock_graphql: AsyncMock) -> None:
|
async def test_generic_exception_wraps(self, _mock_graphql: AsyncMock) -> None:
|
||||||
_mock_graphql.side_effect = RuntimeError("disk error")
|
_mock_graphql.side_effect = RuntimeError("disk error")
|
||||||
tool_fn = _make_tool()
|
tool_fn = _make_tool()
|
||||||
with pytest.raises(ToolError, match="disk error"):
|
with pytest.raises(ToolError, match="Failed to execute array/parity_status"):
|
||||||
await tool_fn(action="parity_status")
|
await tool_fn(action="parity_status")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Tests for unraid_mcp.core.client — GraphQL client infrastructure."""
|
"""Tests for unraid_mcp.core.client — GraphQL client infrastructure."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import time
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@@ -9,9 +10,11 @@ import pytest
|
|||||||
from unraid_mcp.core.client import (
|
from unraid_mcp.core.client import (
|
||||||
DEFAULT_TIMEOUT,
|
DEFAULT_TIMEOUT,
|
||||||
DISK_TIMEOUT,
|
DISK_TIMEOUT,
|
||||||
_redact_sensitive,
|
_QueryCache,
|
||||||
|
_RateLimiter,
|
||||||
is_idempotent_error,
|
is_idempotent_error,
|
||||||
make_graphql_request,
|
make_graphql_request,
|
||||||
|
redact_sensitive,
|
||||||
)
|
)
|
||||||
from unraid_mcp.core.exceptions import ToolError
|
from unraid_mcp.core.exceptions import ToolError
|
||||||
|
|
||||||
@@ -57,7 +60,7 @@ class TestIsIdempotentError:
|
|||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# _redact_sensitive
|
# redact_sensitive
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@@ -66,36 +69,36 @@ class TestRedactSensitive:
|
|||||||
|
|
||||||
def test_flat_dict(self) -> None:
|
def test_flat_dict(self) -> None:
|
||||||
data = {"username": "admin", "password": "hunter2", "host": "10.0.0.1"}
|
data = {"username": "admin", "password": "hunter2", "host": "10.0.0.1"}
|
||||||
result = _redact_sensitive(data)
|
result = redact_sensitive(data)
|
||||||
assert result["username"] == "admin"
|
assert result["username"] == "admin"
|
||||||
assert result["password"] == "***"
|
assert result["password"] == "***"
|
||||||
assert result["host"] == "10.0.0.1"
|
assert result["host"] == "10.0.0.1"
|
||||||
|
|
||||||
def test_nested_dict(self) -> None:
|
def test_nested_dict(self) -> None:
|
||||||
data = {"config": {"apiKey": "abc123", "url": "http://host"}}
|
data = {"config": {"apiKey": "abc123", "url": "http://host"}}
|
||||||
result = _redact_sensitive(data)
|
result = redact_sensitive(data)
|
||||||
assert result["config"]["apiKey"] == "***"
|
assert result["config"]["apiKey"] == "***"
|
||||||
assert result["config"]["url"] == "http://host"
|
assert result["config"]["url"] == "http://host"
|
||||||
|
|
||||||
def test_list_of_dicts(self) -> None:
|
def test_list_of_dicts(self) -> None:
|
||||||
data = [{"token": "t1"}, {"name": "safe"}]
|
data = [{"token": "t1"}, {"name": "safe"}]
|
||||||
result = _redact_sensitive(data)
|
result = redact_sensitive(data)
|
||||||
assert result[0]["token"] == "***"
|
assert result[0]["token"] == "***"
|
||||||
assert result[1]["name"] == "safe"
|
assert result[1]["name"] == "safe"
|
||||||
|
|
||||||
def test_deeply_nested(self) -> None:
|
def test_deeply_nested(self) -> None:
|
||||||
data = {"a": {"b": {"c": {"secret": "deep"}}}}
|
data = {"a": {"b": {"c": {"secret": "deep"}}}}
|
||||||
result = _redact_sensitive(data)
|
result = redact_sensitive(data)
|
||||||
assert result["a"]["b"]["c"]["secret"] == "***"
|
assert result["a"]["b"]["c"]["secret"] == "***"
|
||||||
|
|
||||||
def test_non_dict_passthrough(self) -> None:
|
def test_non_dict_passthrough(self) -> None:
|
||||||
assert _redact_sensitive("plain_string") == "plain_string"
|
assert redact_sensitive("plain_string") == "plain_string"
|
||||||
assert _redact_sensitive(42) == 42
|
assert redact_sensitive(42) == 42
|
||||||
assert _redact_sensitive(None) is None
|
assert redact_sensitive(None) is None
|
||||||
|
|
||||||
def test_case_insensitive_keys(self) -> None:
|
def test_case_insensitive_keys(self) -> None:
|
||||||
data = {"Password": "p1", "TOKEN": "t1", "ApiKey": "k1", "Secret": "s1", "Key": "x1"}
|
data = {"Password": "p1", "TOKEN": "t1", "ApiKey": "k1", "Secret": "s1", "Key": "x1"}
|
||||||
result = _redact_sensitive(data)
|
result = redact_sensitive(data)
|
||||||
for v in result.values():
|
for v in result.values():
|
||||||
assert v == "***"
|
assert v == "***"
|
||||||
|
|
||||||
@@ -109,7 +112,7 @@ class TestRedactSensitive:
|
|||||||
"username": "safe",
|
"username": "safe",
|
||||||
"host": "safe",
|
"host": "safe",
|
||||||
}
|
}
|
||||||
result = _redact_sensitive(data)
|
result = redact_sensitive(data)
|
||||||
assert result["user_password"] == "***"
|
assert result["user_password"] == "***"
|
||||||
assert result["api_key_value"] == "***"
|
assert result["api_key_value"] == "***"
|
||||||
assert result["auth_token_expiry"] == "***"
|
assert result["auth_token_expiry"] == "***"
|
||||||
@@ -119,12 +122,26 @@ class TestRedactSensitive:
|
|||||||
|
|
||||||
def test_mixed_list_content(self) -> None:
|
def test_mixed_list_content(self) -> None:
|
||||||
data = [{"key": "val"}, "string", 123, [{"token": "inner"}]]
|
data = [{"key": "val"}, "string", 123, [{"token": "inner"}]]
|
||||||
result = _redact_sensitive(data)
|
result = redact_sensitive(data)
|
||||||
assert result[0]["key"] == "***"
|
assert result[0]["key"] == "***"
|
||||||
assert result[1] == "string"
|
assert result[1] == "string"
|
||||||
assert result[2] == 123
|
assert result[2] == 123
|
||||||
assert result[3][0]["token"] == "***"
|
assert result[3][0]["token"] == "***"
|
||||||
|
|
||||||
|
def test_new_sensitive_keys_are_redacted(self) -> None:
|
||||||
|
"""PR-added keys: authorization, cookie, session, credential, passphrase, jwt."""
|
||||||
|
data = {
|
||||||
|
"authorization": "Bearer token123",
|
||||||
|
"cookie": "session=abc",
|
||||||
|
"jwt": "eyJ...",
|
||||||
|
"credential": "secret_cred",
|
||||||
|
"passphrase": "hunter2",
|
||||||
|
"session": "sess_id",
|
||||||
|
}
|
||||||
|
result = redact_sensitive(data)
|
||||||
|
for key, val in result.items():
|
||||||
|
assert val == "***", f"Key '{key}' was not redacted"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Timeout constants
|
# Timeout constants
|
||||||
@@ -274,7 +291,7 @@ class TestMakeGraphQLRequestErrors:
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
|
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
|
||||||
pytest.raises(ToolError, match="HTTP error 401"),
|
pytest.raises(ToolError, match="Unraid API returned HTTP 401"),
|
||||||
):
|
):
|
||||||
await make_graphql_request("{ info }")
|
await make_graphql_request("{ info }")
|
||||||
|
|
||||||
@@ -292,7 +309,7 @@ class TestMakeGraphQLRequestErrors:
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
|
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
|
||||||
pytest.raises(ToolError, match="HTTP error 500"),
|
pytest.raises(ToolError, match="Unraid API returned HTTP 500"),
|
||||||
):
|
):
|
||||||
await make_graphql_request("{ info }")
|
await make_graphql_request("{ info }")
|
||||||
|
|
||||||
@@ -310,7 +327,7 @@ class TestMakeGraphQLRequestErrors:
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
|
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
|
||||||
pytest.raises(ToolError, match="HTTP error 503"),
|
pytest.raises(ToolError, match="Unraid API returned HTTP 503"),
|
||||||
):
|
):
|
||||||
await make_graphql_request("{ info }")
|
await make_graphql_request("{ info }")
|
||||||
|
|
||||||
@@ -320,7 +337,7 @@ class TestMakeGraphQLRequestErrors:
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
|
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
|
||||||
pytest.raises(ToolError, match="Network connection error"),
|
pytest.raises(ToolError, match="Network error connecting to Unraid API"),
|
||||||
):
|
):
|
||||||
await make_graphql_request("{ info }")
|
await make_graphql_request("{ info }")
|
||||||
|
|
||||||
@@ -330,7 +347,7 @@ class TestMakeGraphQLRequestErrors:
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
|
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
|
||||||
pytest.raises(ToolError, match="Network connection error"),
|
pytest.raises(ToolError, match="Network error connecting to Unraid API"),
|
||||||
):
|
):
|
||||||
await make_graphql_request("{ info }")
|
await make_graphql_request("{ info }")
|
||||||
|
|
||||||
@@ -344,7 +361,7 @@ class TestMakeGraphQLRequestErrors:
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
|
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
|
||||||
pytest.raises(ToolError, match="Invalid JSON response"),
|
pytest.raises(ToolError, match=r"invalid response.*not valid JSON"),
|
||||||
):
|
):
|
||||||
await make_graphql_request("{ info }")
|
await make_graphql_request("{ info }")
|
||||||
|
|
||||||
@@ -464,3 +481,240 @@ class TestGraphQLErrorHandling:
|
|||||||
pytest.raises(ToolError, match="GraphQL API error"),
|
pytest.raises(ToolError, match="GraphQL API error"),
|
||||||
):
|
):
|
||||||
await make_graphql_request("{ info }")
|
await make_graphql_request("{ info }")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _RateLimiter
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRateLimiter:
|
||||||
|
"""Unit tests for the token bucket rate limiter."""
|
||||||
|
|
||||||
|
async def test_acquire_consumes_one_token(self) -> None:
|
||||||
|
limiter = _RateLimiter(max_tokens=10, refill_rate=1.0)
|
||||||
|
initial = limiter.tokens
|
||||||
|
await limiter.acquire()
|
||||||
|
assert limiter.tokens == pytest.approx(initial - 1, abs=1e-3)
|
||||||
|
|
||||||
|
async def test_acquire_succeeds_when_tokens_available(self) -> None:
|
||||||
|
limiter = _RateLimiter(max_tokens=5, refill_rate=1.0)
|
||||||
|
# Should complete without sleeping
|
||||||
|
for _ in range(5):
|
||||||
|
await limiter.acquire()
|
||||||
|
# _refill() runs during each acquire() call and adds a tiny time-based
|
||||||
|
# amount; check < 1.0 (not enough for another immediate request) rather
|
||||||
|
# than == 0.0 to avoid flakiness from timing.
|
||||||
|
assert limiter.tokens < 1.0
|
||||||
|
|
||||||
|
async def test_tokens_do_not_exceed_max(self) -> None:
|
||||||
|
limiter = _RateLimiter(max_tokens=10, refill_rate=1.0)
|
||||||
|
# Force refill with large elapsed time
|
||||||
|
limiter.last_refill = time.monotonic() - 100.0 # 100 seconds ago
|
||||||
|
limiter._refill()
|
||||||
|
assert limiter.tokens == 10.0 # Capped at max_tokens
|
||||||
|
|
||||||
|
async def test_refill_adds_tokens_based_on_elapsed(self) -> None:
|
||||||
|
limiter = _RateLimiter(max_tokens=100, refill_rate=10.0)
|
||||||
|
limiter.tokens = 0.0
|
||||||
|
limiter.last_refill = time.monotonic() - 1.0 # 1 second ago
|
||||||
|
limiter._refill()
|
||||||
|
# Should have refilled ~10 tokens (10.0 rate * 1.0 sec)
|
||||||
|
assert 9.5 < limiter.tokens < 10.5
|
||||||
|
|
||||||
|
async def test_acquire_sleeps_when_no_tokens(self) -> None:
|
||||||
|
"""When tokens are exhausted, acquire should sleep before consuming."""
|
||||||
|
limiter = _RateLimiter(max_tokens=1, refill_rate=1.0)
|
||||||
|
limiter.tokens = 0.0
|
||||||
|
|
||||||
|
sleep_calls = []
|
||||||
|
|
||||||
|
async def fake_sleep(duration: float) -> None:
|
||||||
|
sleep_calls.append(duration)
|
||||||
|
# Simulate refill by advancing last_refill so tokens replenish
|
||||||
|
limiter.tokens = 1.0
|
||||||
|
limiter.last_refill = time.monotonic()
|
||||||
|
|
||||||
|
with patch("unraid_mcp.core.client.asyncio.sleep", side_effect=fake_sleep):
|
||||||
|
await limiter.acquire()
|
||||||
|
|
||||||
|
assert len(sleep_calls) == 1
|
||||||
|
assert sleep_calls[0] > 0
|
||||||
|
|
||||||
|
async def test_default_params_match_api_limits(self) -> None:
|
||||||
|
"""Default rate limiter must use 90 tokens at 9.0/sec (10% headroom from 100/10s)."""
|
||||||
|
limiter = _RateLimiter()
|
||||||
|
assert limiter.max_tokens == 90
|
||||||
|
assert limiter.refill_rate == 9.0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _QueryCache
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestQueryCache:
|
||||||
|
"""Unit tests for the TTL query cache."""
|
||||||
|
|
||||||
|
def test_miss_on_empty_cache(self) -> None:
|
||||||
|
cache = _QueryCache()
|
||||||
|
assert cache.get("{ info }", None) is None
|
||||||
|
|
||||||
|
def test_put_and_get_hit(self) -> None:
|
||||||
|
cache = _QueryCache()
|
||||||
|
data = {"result": "ok"}
|
||||||
|
cache.put("GetNetworkConfig { }", None, data)
|
||||||
|
result = cache.get("GetNetworkConfig { }", None)
|
||||||
|
assert result == data
|
||||||
|
|
||||||
|
def test_expired_entry_returns_none(self) -> None:
|
||||||
|
cache = _QueryCache()
|
||||||
|
data = {"result": "ok"}
|
||||||
|
cache.put("GetNetworkConfig { }", None, data)
|
||||||
|
# Manually expire the entry
|
||||||
|
key = cache._cache_key("GetNetworkConfig { }", None)
|
||||||
|
cache._store[key] = (time.monotonic() - 1.0, data) # expired 1 sec ago
|
||||||
|
assert cache.get("GetNetworkConfig { }", None) is None
|
||||||
|
|
||||||
|
def test_invalidate_all_clears_store(self) -> None:
|
||||||
|
cache = _QueryCache()
|
||||||
|
cache.put("GetNetworkConfig { }", None, {"x": 1})
|
||||||
|
cache.put("GetOwner { }", None, {"y": 2})
|
||||||
|
assert len(cache._store) == 2
|
||||||
|
cache.invalidate_all()
|
||||||
|
assert len(cache._store) == 0
|
||||||
|
|
||||||
|
def test_variables_affect_cache_key(self) -> None:
|
||||||
|
"""Different variables produce different cache keys."""
|
||||||
|
cache = _QueryCache()
|
||||||
|
q = "GetNetworkConfig($id: ID!) { network(id: $id) { name } }"
|
||||||
|
cache.put(q, {"id": "1"}, {"name": "eth0"})
|
||||||
|
cache.put(q, {"id": "2"}, {"name": "eth1"})
|
||||||
|
assert cache.get(q, {"id": "1"}) == {"name": "eth0"}
|
||||||
|
assert cache.get(q, {"id": "2"}) == {"name": "eth1"}
|
||||||
|
|
||||||
|
def test_is_cacheable_returns_true_for_known_prefixes(self) -> None:
|
||||||
|
assert _QueryCache.is_cacheable("GetNetworkConfig { ... }") is True
|
||||||
|
assert _QueryCache.is_cacheable("GetRegistrationInfo { ... }") is True
|
||||||
|
assert _QueryCache.is_cacheable("GetOwner { ... }") is True
|
||||||
|
assert _QueryCache.is_cacheable("GetFlash { ... }") is True
|
||||||
|
|
||||||
|
def test_is_cacheable_returns_false_for_mutations(self) -> None:
|
||||||
|
assert _QueryCache.is_cacheable('mutation { docker { start(id: "x") } }') is False
|
||||||
|
|
||||||
|
def test_is_cacheable_returns_false_for_unlisted_queries(self) -> None:
|
||||||
|
assert _QueryCache.is_cacheable("{ docker { containers { id } } }") is False
|
||||||
|
assert _QueryCache.is_cacheable("{ info { os } }") is False
|
||||||
|
|
||||||
|
def test_is_cacheable_mutation_check_is_prefix(self) -> None:
|
||||||
|
"""Queries that start with 'mutation' after whitespace are not cacheable."""
|
||||||
|
assert _QueryCache.is_cacheable(" mutation { ... }") is False
|
||||||
|
|
||||||
|
def test_is_cacheable_with_explicit_query_keyword(self) -> None:
|
||||||
|
"""Operation names after explicit 'query' keyword must be recognized."""
|
||||||
|
assert _QueryCache.is_cacheable("query GetNetworkConfig { network { name } }") is True
|
||||||
|
assert _QueryCache.is_cacheable("query GetOwner { owner { name } }") is True
|
||||||
|
|
||||||
|
def test_is_cacheable_anonymous_query_returns_false(self) -> None:
|
||||||
|
"""Anonymous 'query { ... }' has no operation name — must not be cached."""
|
||||||
|
assert _QueryCache.is_cacheable("query { network { name } }") is False
|
||||||
|
|
||||||
|
def test_expired_entry_removed_from_store(self) -> None:
|
||||||
|
"""Accessing an expired entry should remove it from the internal store."""
|
||||||
|
cache = _QueryCache()
|
||||||
|
cache.put("GetOwner { }", None, {"owner": "root"})
|
||||||
|
key = cache._cache_key("GetOwner { }", None)
|
||||||
|
cache._store[key] = (time.monotonic() - 1.0, {"owner": "root"})
|
||||||
|
assert key in cache._store
|
||||||
|
cache.get("GetOwner { }", None) # triggers deletion
|
||||||
|
assert key not in cache._store
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# make_graphql_request — 429 retry behavior
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRateLimitRetry:
|
||||||
|
"""Tests for the 429 retry loop in make_graphql_request."""
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _patch_config(self):
|
||||||
|
with (
|
||||||
|
patch("unraid_mcp.core.client.UNRAID_API_URL", "https://unraid.local/graphql"),
|
||||||
|
patch("unraid_mcp.core.client.UNRAID_API_KEY", "test-key"),
|
||||||
|
patch("unraid_mcp.core.client.asyncio.sleep", new_callable=AsyncMock),
|
||||||
|
):
|
||||||
|
yield
|
||||||
|
|
||||||
|
def _make_429_response(self) -> MagicMock:
|
||||||
|
resp = MagicMock()
|
||||||
|
resp.status_code = 429
|
||||||
|
resp.raise_for_status = MagicMock()
|
||||||
|
return resp
|
||||||
|
|
||||||
|
def _make_ok_response(self, data: dict) -> MagicMock:
|
||||||
|
resp = MagicMock()
|
||||||
|
resp.status_code = 200
|
||||||
|
resp.raise_for_status = MagicMock()
|
||||||
|
resp.json.return_value = {"data": data}
|
||||||
|
return resp
|
||||||
|
|
||||||
|
async def test_single_429_then_success_retries(self) -> None:
|
||||||
|
"""One 429 followed by a success should return the data."""
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post.side_effect = [
|
||||||
|
self._make_429_response(),
|
||||||
|
self._make_ok_response({"info": {"os": "Unraid"}}),
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch("unraid_mcp.core.client.get_http_client", return_value=mock_client):
|
||||||
|
result = await make_graphql_request("{ info { os } }")
|
||||||
|
|
||||||
|
assert result == {"info": {"os": "Unraid"}}
|
||||||
|
assert mock_client.post.call_count == 2
|
||||||
|
|
||||||
|
async def test_two_429s_then_success(self) -> None:
|
||||||
|
"""Two 429s followed by success returns data after 2 retries."""
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post.side_effect = [
|
||||||
|
self._make_429_response(),
|
||||||
|
self._make_429_response(),
|
||||||
|
self._make_ok_response({"x": 1}),
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch("unraid_mcp.core.client.get_http_client", return_value=mock_client):
|
||||||
|
result = await make_graphql_request("{ x }")
|
||||||
|
|
||||||
|
assert result == {"x": 1}
|
||||||
|
assert mock_client.post.call_count == 3
|
||||||
|
|
||||||
|
async def test_three_429s_raises_tool_error(self) -> None:
|
||||||
|
"""Three consecutive 429s (all retries exhausted) raises ToolError."""
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post.side_effect = [
|
||||||
|
self._make_429_response(),
|
||||||
|
self._make_429_response(),
|
||||||
|
self._make_429_response(),
|
||||||
|
]
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
|
||||||
|
pytest.raises(ToolError, match="rate limiting"),
|
||||||
|
):
|
||||||
|
await make_graphql_request("{ info }")
|
||||||
|
|
||||||
|
async def test_rate_limit_error_message_advises_wait(self) -> None:
|
||||||
|
"""The ToolError message should tell the user to wait ~10 seconds."""
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post.side_effect = [
|
||||||
|
self._make_429_response(),
|
||||||
|
self._make_429_response(),
|
||||||
|
self._make_429_response(),
|
||||||
|
]
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("unraid_mcp.core.client.get_http_client", return_value=mock_client),
|
||||||
|
pytest.raises(ToolError, match="10 seconds"),
|
||||||
|
):
|
||||||
|
await make_graphql_request("{ info }")
|
||||||
|
|||||||
@@ -80,6 +80,14 @@ class TestDockerValidation:
|
|||||||
with pytest.raises(ToolError, match="network_id"):
|
with pytest.raises(ToolError, match="network_id"):
|
||||||
await tool_fn(action="network_details")
|
await tool_fn(action="network_details")
|
||||||
|
|
||||||
|
async def test_non_logs_action_ignores_tail_lines_validation(
|
||||||
|
self, _mock_graphql: AsyncMock
|
||||||
|
) -> None:
|
||||||
|
_mock_graphql.return_value = {"docker": {"containers": []}}
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
result = await tool_fn(action="list", tail_lines=0)
|
||||||
|
assert result["containers"] == []
|
||||||
|
|
||||||
|
|
||||||
class TestDockerActions:
|
class TestDockerActions:
|
||||||
async def test_list(self, _mock_graphql: AsyncMock) -> None:
|
async def test_list(self, _mock_graphql: AsyncMock) -> None:
|
||||||
@@ -175,7 +183,7 @@ class TestDockerActions:
|
|||||||
"docker": {"updateAllContainers": [{"id": "c1", "state": "running"}]}
|
"docker": {"updateAllContainers": [{"id": "c1", "state": "running"}]}
|
||||||
}
|
}
|
||||||
tool_fn = _make_tool()
|
tool_fn = _make_tool()
|
||||||
result = await tool_fn(action="update_all")
|
result = await tool_fn(action="update_all", confirm=True)
|
||||||
assert result["success"] is True
|
assert result["success"] is True
|
||||||
assert len(result["containers"]) == 1
|
assert len(result["containers"]) == 1
|
||||||
|
|
||||||
@@ -224,9 +232,22 @@ class TestDockerActions:
|
|||||||
async def test_generic_exception_wraps_in_tool_error(self, _mock_graphql: AsyncMock) -> None:
|
async def test_generic_exception_wraps_in_tool_error(self, _mock_graphql: AsyncMock) -> None:
|
||||||
_mock_graphql.side_effect = RuntimeError("unexpected failure")
|
_mock_graphql.side_effect = RuntimeError("unexpected failure")
|
||||||
tool_fn = _make_tool()
|
tool_fn = _make_tool()
|
||||||
with pytest.raises(ToolError, match="unexpected failure"):
|
with pytest.raises(ToolError, match="Failed to execute docker/list"):
|
||||||
await tool_fn(action="list")
|
await tool_fn(action="list")
|
||||||
|
|
||||||
|
async def test_short_id_prefix_ambiguous_rejected(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
_mock_graphql.return_value = {
|
||||||
|
"docker": {
|
||||||
|
"containers": [
|
||||||
|
{"id": "abcdef1234560000000000000000000000000000000000000000000000000000:local", "names": ["plex"]},
|
||||||
|
{"id": "abcdef1234561111111111111111111111111111111111111111111111111111:local", "names": ["sonarr"]},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
with pytest.raises(ToolError, match="ambiguous"):
|
||||||
|
await tool_fn(action="logs", container_id="abcdef123456")
|
||||||
|
|
||||||
|
|
||||||
class TestDockerMutationFailures:
|
class TestDockerMutationFailures:
|
||||||
"""Tests for mutation responses that indicate failure or unexpected shapes."""
|
"""Tests for mutation responses that indicate failure or unexpected shapes."""
|
||||||
@@ -271,10 +292,16 @@ class TestDockerMutationFailures:
|
|||||||
"""update_all with no containers to update."""
|
"""update_all with no containers to update."""
|
||||||
_mock_graphql.return_value = {"docker": {"updateAllContainers": []}}
|
_mock_graphql.return_value = {"docker": {"updateAllContainers": []}}
|
||||||
tool_fn = _make_tool()
|
tool_fn = _make_tool()
|
||||||
result = await tool_fn(action="update_all")
|
result = await tool_fn(action="update_all", confirm=True)
|
||||||
assert result["success"] is True
|
assert result["success"] is True
|
||||||
assert result["containers"] == []
|
assert result["containers"] == []
|
||||||
|
|
||||||
|
async def test_update_all_requires_confirm(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
"""update_all is destructive and requires confirm=True."""
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
with pytest.raises(ToolError, match="destructive"):
|
||||||
|
await tool_fn(action="update_all")
|
||||||
|
|
||||||
async def test_mutation_timeout(self, _mock_graphql: AsyncMock) -> None:
|
async def test_mutation_timeout(self, _mock_graphql: AsyncMock) -> None:
|
||||||
"""Mid-operation timeout during a docker mutation."""
|
"""Mid-operation timeout during a docker mutation."""
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import pytest
|
|||||||
from conftest import make_tool_fn
|
from conftest import make_tool_fn
|
||||||
|
|
||||||
from unraid_mcp.core.exceptions import ToolError
|
from unraid_mcp.core.exceptions import ToolError
|
||||||
|
from unraid_mcp.core.utils import safe_display_url
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -99,7 +100,7 @@ class TestHealthActions:
|
|||||||
"unraid_mcp.tools.health._diagnose_subscriptions",
|
"unraid_mcp.tools.health._diagnose_subscriptions",
|
||||||
side_effect=RuntimeError("broken"),
|
side_effect=RuntimeError("broken"),
|
||||||
),
|
),
|
||||||
pytest.raises(ToolError, match="broken"),
|
pytest.raises(ToolError, match="Failed to execute health/diagnose"),
|
||||||
):
|
):
|
||||||
await tool_fn(action="diagnose")
|
await tool_fn(action="diagnose")
|
||||||
|
|
||||||
@@ -114,7 +115,7 @@ class TestHealthActions:
|
|||||||
assert "cpu_sub" in result
|
assert "cpu_sub" in result
|
||||||
|
|
||||||
async def test_diagnose_import_error_internal(self) -> None:
|
async def test_diagnose_import_error_internal(self) -> None:
|
||||||
"""_diagnose_subscriptions catches ImportError and returns error dict."""
|
"""_diagnose_subscriptions raises ToolError when subscription modules are unavailable."""
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from unraid_mcp.tools.health import _diagnose_subscriptions
|
from unraid_mcp.tools.health import _diagnose_subscriptions
|
||||||
@@ -126,16 +127,70 @@ class TestHealthActions:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Replace the modules with objects that raise ImportError on access
|
# Replace the modules with objects that raise ImportError on access
|
||||||
with patch.dict(
|
with (
|
||||||
sys.modules,
|
patch.dict(
|
||||||
{
|
sys.modules,
|
||||||
"unraid_mcp.subscriptions": None,
|
{
|
||||||
"unraid_mcp.subscriptions.manager": None,
|
"unraid_mcp.subscriptions": None,
|
||||||
"unraid_mcp.subscriptions.resources": None,
|
"unraid_mcp.subscriptions.manager": None,
|
||||||
},
|
"unraid_mcp.subscriptions.resources": None,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
pytest.raises(ToolError, match="Subscription modules not available"),
|
||||||
):
|
):
|
||||||
result = await _diagnose_subscriptions()
|
await _diagnose_subscriptions()
|
||||||
assert "error" in result
|
|
||||||
finally:
|
finally:
|
||||||
# Restore cached modules
|
# Restore cached modules
|
||||||
sys.modules.update(cached)
|
sys.modules.update(cached)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _safe_display_url — URL redaction helper
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSafeDisplayUrl:
|
||||||
|
"""Verify that safe_display_url strips credentials/path and preserves scheme+host+port."""
|
||||||
|
|
||||||
|
def test_none_returns_none(self) -> None:
|
||||||
|
assert safe_display_url(None) is None
|
||||||
|
|
||||||
|
def test_empty_string_returns_none(self) -> None:
|
||||||
|
assert safe_display_url("") is None
|
||||||
|
|
||||||
|
def test_simple_url_scheme_and_host(self) -> None:
|
||||||
|
assert safe_display_url("https://unraid.local/graphql") == "https://unraid.local"
|
||||||
|
|
||||||
|
def test_preserves_port(self) -> None:
|
||||||
|
assert safe_display_url("https://10.1.0.2:31337/api/graphql") == "https://10.1.0.2:31337"
|
||||||
|
|
||||||
|
def test_strips_path(self) -> None:
|
||||||
|
result = safe_display_url("http://unraid.local/some/deep/path?query=1")
|
||||||
|
assert "path" not in result
|
||||||
|
assert "query" not in result
|
||||||
|
|
||||||
|
def test_strips_credentials(self) -> None:
|
||||||
|
result = safe_display_url("https://user:password@unraid.local/graphql")
|
||||||
|
assert "user" not in result
|
||||||
|
assert "password" not in result
|
||||||
|
assert result == "https://unraid.local"
|
||||||
|
|
||||||
|
def test_strips_query_params(self) -> None:
|
||||||
|
result = safe_display_url("http://host.local?token=abc&key=xyz")
|
||||||
|
assert "token" not in result
|
||||||
|
assert "abc" not in result
|
||||||
|
|
||||||
|
def test_http_scheme_preserved(self) -> None:
|
||||||
|
result = safe_display_url("http://10.0.0.1:8080/api")
|
||||||
|
assert result == "http://10.0.0.1:8080"
|
||||||
|
|
||||||
|
def test_tailscale_url(self) -> None:
|
||||||
|
result = safe_display_url("https://100.118.209.1:31337/graphql")
|
||||||
|
assert result == "https://100.118.209.1:31337"
|
||||||
|
|
||||||
|
def test_malformed_ipv6_url_returns_unparseable(self) -> None:
|
||||||
|
"""Malformed IPv6 brackets in netloc cause urlparse.hostname to raise ValueError."""
|
||||||
|
# urlparse("https://[invalid") parses without error, but accessing .hostname
|
||||||
|
# raises ValueError: Invalid IPv6 URL — this triggers the except branch.
|
||||||
|
result = safe_display_url("https://[invalid")
|
||||||
|
assert result == "<unparseable>"
|
||||||
|
|||||||
@@ -186,7 +186,7 @@ class TestUnraidInfoTool:
|
|||||||
async def test_generic_exception_wraps(self, _mock_graphql: AsyncMock) -> None:
|
async def test_generic_exception_wraps(self, _mock_graphql: AsyncMock) -> None:
|
||||||
_mock_graphql.side_effect = RuntimeError("unexpected")
|
_mock_graphql.side_effect = RuntimeError("unexpected")
|
||||||
tool_fn = _make_tool()
|
tool_fn = _make_tool()
|
||||||
with pytest.raises(ToolError, match="unexpected"):
|
with pytest.raises(ToolError, match="Failed to execute info/online"):
|
||||||
await tool_fn(action="online")
|
await tool_fn(action="online")
|
||||||
|
|
||||||
async def test_metrics(self, _mock_graphql: AsyncMock) -> None:
|
async def test_metrics(self, _mock_graphql: AsyncMock) -> None:
|
||||||
@@ -201,6 +201,7 @@ class TestUnraidInfoTool:
|
|||||||
_mock_graphql.return_value = {"services": [{"name": "docker", "state": "running"}]}
|
_mock_graphql.return_value = {"services": [{"name": "docker", "state": "running"}]}
|
||||||
tool_fn = _make_tool()
|
tool_fn = _make_tool()
|
||||||
result = await tool_fn(action="services")
|
result = await tool_fn(action="services")
|
||||||
|
assert "services" in result
|
||||||
assert len(result["services"]) == 1
|
assert len(result["services"]) == 1
|
||||||
assert result["services"][0]["name"] == "docker"
|
assert result["services"][0]["name"] == "docker"
|
||||||
|
|
||||||
@@ -225,6 +226,7 @@ class TestUnraidInfoTool:
|
|||||||
}
|
}
|
||||||
tool_fn = _make_tool()
|
tool_fn = _make_tool()
|
||||||
result = await tool_fn(action="servers")
|
result = await tool_fn(action="servers")
|
||||||
|
assert "servers" in result
|
||||||
assert len(result["servers"]) == 1
|
assert len(result["servers"]) == 1
|
||||||
assert result["servers"][0]["name"] == "tower"
|
assert result["servers"][0]["name"] == "tower"
|
||||||
|
|
||||||
@@ -248,6 +250,7 @@ class TestUnraidInfoTool:
|
|||||||
}
|
}
|
||||||
tool_fn = _make_tool()
|
tool_fn = _make_tool()
|
||||||
result = await tool_fn(action="ups_devices")
|
result = await tool_fn(action="ups_devices")
|
||||||
|
assert "ups_devices" in result
|
||||||
assert len(result["ups_devices"]) == 1
|
assert len(result["ups_devices"]) == 1
|
||||||
assert result["ups_devices"][0]["model"] == "APC"
|
assert result["ups_devices"][0]["model"] == "APC"
|
||||||
|
|
||||||
|
|||||||
@@ -100,5 +100,5 @@ class TestKeysActions:
|
|||||||
async def test_generic_exception_wraps(self, _mock_graphql: AsyncMock) -> None:
|
async def test_generic_exception_wraps(self, _mock_graphql: AsyncMock) -> None:
|
||||||
_mock_graphql.side_effect = RuntimeError("connection lost")
|
_mock_graphql.side_effect = RuntimeError("connection lost")
|
||||||
tool_fn = _make_tool()
|
tool_fn = _make_tool()
|
||||||
with pytest.raises(ToolError, match="connection lost"):
|
with pytest.raises(ToolError, match="Failed to execute keys/list"):
|
||||||
await tool_fn(action="list")
|
await tool_fn(action="list")
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ class TestNotificationsActions:
|
|||||||
title="Test",
|
title="Test",
|
||||||
subject="Test Subject",
|
subject="Test Subject",
|
||||||
description="Test Desc",
|
description="Test Desc",
|
||||||
importance="info",
|
importance="normal",
|
||||||
)
|
)
|
||||||
assert result["success"] is True
|
assert result["success"] is True
|
||||||
|
|
||||||
@@ -149,5 +149,89 @@ class TestNotificationsActions:
|
|||||||
async def test_generic_exception_wraps(self, _mock_graphql: AsyncMock) -> None:
|
async def test_generic_exception_wraps(self, _mock_graphql: AsyncMock) -> None:
|
||||||
_mock_graphql.side_effect = RuntimeError("boom")
|
_mock_graphql.side_effect = RuntimeError("boom")
|
||||||
tool_fn = _make_tool()
|
tool_fn = _make_tool()
|
||||||
with pytest.raises(ToolError, match="boom"):
|
with pytest.raises(ToolError, match="Failed to execute notifications/overview"):
|
||||||
await tool_fn(action="overview")
|
await tool_fn(action="overview")
|
||||||
|
|
||||||
|
|
||||||
|
class TestNotificationsCreateValidation:
|
||||||
|
"""Tests for importance enum and field length validation added in this PR."""
|
||||||
|
|
||||||
|
async def test_invalid_importance_rejected(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
with pytest.raises(ToolError, match="importance must be one of"):
|
||||||
|
await tool_fn(
|
||||||
|
action="create",
|
||||||
|
title="T",
|
||||||
|
subject="S",
|
||||||
|
description="D",
|
||||||
|
importance="invalid",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_info_importance_rejected(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
"""INFO is listed in old docstring examples but rejected by the validator."""
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
with pytest.raises(ToolError, match="importance must be one of"):
|
||||||
|
await tool_fn(
|
||||||
|
action="create",
|
||||||
|
title="T",
|
||||||
|
subject="S",
|
||||||
|
description="D",
|
||||||
|
importance="info",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_alert_importance_accepted(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
_mock_graphql.return_value = {
|
||||||
|
"notifications": {"createNotification": {"id": "n:1", "importance": "ALERT"}}
|
||||||
|
}
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
result = await tool_fn(
|
||||||
|
action="create", title="T", subject="S", description="D", importance="alert"
|
||||||
|
)
|
||||||
|
assert result["success"] is True
|
||||||
|
|
||||||
|
async def test_title_too_long_rejected(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
with pytest.raises(ToolError, match="title must be at most 200"):
|
||||||
|
await tool_fn(
|
||||||
|
action="create",
|
||||||
|
title="x" * 201,
|
||||||
|
subject="S",
|
||||||
|
description="D",
|
||||||
|
importance="normal",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_subject_too_long_rejected(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
with pytest.raises(ToolError, match="subject must be at most 500"):
|
||||||
|
await tool_fn(
|
||||||
|
action="create",
|
||||||
|
title="T",
|
||||||
|
subject="x" * 501,
|
||||||
|
description="D",
|
||||||
|
importance="normal",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_description_too_long_rejected(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
with pytest.raises(ToolError, match="description must be at most 2000"):
|
||||||
|
await tool_fn(
|
||||||
|
action="create",
|
||||||
|
title="T",
|
||||||
|
subject="S",
|
||||||
|
description="x" * 2001,
|
||||||
|
importance="normal",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_title_at_max_accepted(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
_mock_graphql.return_value = {
|
||||||
|
"notifications": {"createNotification": {"id": "n:1", "importance": "NORMAL"}}
|
||||||
|
}
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
result = await tool_fn(
|
||||||
|
action="create",
|
||||||
|
title="x" * 200,
|
||||||
|
subject="S",
|
||||||
|
description="D",
|
||||||
|
importance="normal",
|
||||||
|
)
|
||||||
|
assert result["success"] is True
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ def _make_tool():
|
|||||||
return make_tool_fn("unraid_mcp.tools.rclone", "register_rclone_tool", "unraid_rclone")
|
return make_tool_fn("unraid_mcp.tools.rclone", "register_rclone_tool", "unraid_rclone")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("_mock_graphql")
|
|
||||||
class TestRcloneValidation:
|
class TestRcloneValidation:
|
||||||
async def test_delete_requires_confirm(self) -> None:
|
async def test_delete_requires_confirm(self) -> None:
|
||||||
tool_fn = _make_tool()
|
tool_fn = _make_tool()
|
||||||
@@ -100,3 +99,83 @@ class TestRcloneActions:
|
|||||||
tool_fn = _make_tool()
|
tool_fn = _make_tool()
|
||||||
with pytest.raises(ToolError, match="Failed to delete"):
|
with pytest.raises(ToolError, match="Failed to delete"):
|
||||||
await tool_fn(action="delete_remote", name="gdrive", confirm=True)
|
await tool_fn(action="delete_remote", name="gdrive", confirm=True)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRcloneConfigDataValidation:
|
||||||
|
"""Tests for _validate_config_data security guards."""
|
||||||
|
|
||||||
|
async def test_path_traversal_in_key_rejected(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
with pytest.raises(ToolError, match="disallowed characters"):
|
||||||
|
await tool_fn(
|
||||||
|
action="create_remote",
|
||||||
|
name="r",
|
||||||
|
provider_type="s3",
|
||||||
|
config_data={"../evil": "value"},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_shell_metachar_in_key_rejected(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
with pytest.raises(ToolError, match="disallowed characters"):
|
||||||
|
await tool_fn(
|
||||||
|
action="create_remote",
|
||||||
|
name="r",
|
||||||
|
provider_type="s3",
|
||||||
|
config_data={"key;rm": "value"},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_too_many_keys_rejected(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
with pytest.raises(ToolError, match="max 50"):
|
||||||
|
await tool_fn(
|
||||||
|
action="create_remote",
|
||||||
|
name="r",
|
||||||
|
provider_type="s3",
|
||||||
|
config_data={f"key{i}": "v" for i in range(51)},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_dict_value_rejected(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
with pytest.raises(ToolError, match="string, number, or boolean"):
|
||||||
|
await tool_fn(
|
||||||
|
action="create_remote",
|
||||||
|
name="r",
|
||||||
|
provider_type="s3",
|
||||||
|
config_data={"nested": {"key": "val"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_value_too_long_rejected(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
with pytest.raises(ToolError, match="exceeds max length"):
|
||||||
|
await tool_fn(
|
||||||
|
action="create_remote",
|
||||||
|
name="r",
|
||||||
|
provider_type="s3",
|
||||||
|
config_data={"key": "x" * 4097},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_boolean_value_accepted(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
_mock_graphql.return_value = {
|
||||||
|
"rclone": {"createRCloneRemote": {"name": "r", "type": "s3"}}
|
||||||
|
}
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
result = await tool_fn(
|
||||||
|
action="create_remote",
|
||||||
|
name="r",
|
||||||
|
provider_type="s3",
|
||||||
|
config_data={"use_path_style": True},
|
||||||
|
)
|
||||||
|
assert result["success"] is True
|
||||||
|
|
||||||
|
async def test_int_value_accepted(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
_mock_graphql.return_value = {
|
||||||
|
"rclone": {"createRCloneRemote": {"name": "r", "type": "sftp"}}
|
||||||
|
}
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
result = await tool_fn(
|
||||||
|
action="create_remote",
|
||||||
|
name="r",
|
||||||
|
provider_type="sftp",
|
||||||
|
config_data={"port": 22},
|
||||||
|
)
|
||||||
|
assert result["success"] is True
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import pytest
|
|||||||
from conftest import make_tool_fn
|
from conftest import make_tool_fn
|
||||||
|
|
||||||
from unraid_mcp.core.exceptions import ToolError
|
from unraid_mcp.core.exceptions import ToolError
|
||||||
from unraid_mcp.tools.storage import format_bytes
|
from unraid_mcp.core.utils import format_bytes, format_kb, safe_get
|
||||||
|
|
||||||
|
|
||||||
# --- Unit tests for helpers ---
|
# --- Unit tests for helpers ---
|
||||||
@@ -77,6 +77,87 @@ class TestStorageValidation:
|
|||||||
result = await tool_fn(action="logs", log_path="/var/log/syslog")
|
result = await tool_fn(action="logs", log_path="/var/log/syslog")
|
||||||
assert result["content"] == "ok"
|
assert result["content"] == "ok"
|
||||||
|
|
||||||
|
async def test_logs_tail_lines_too_large(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
with pytest.raises(ToolError, match="tail_lines must be between"):
|
||||||
|
await tool_fn(action="logs", log_path="/var/log/syslog", tail_lines=10_001)
|
||||||
|
|
||||||
|
async def test_logs_tail_lines_zero_rejected(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
with pytest.raises(ToolError, match="tail_lines must be between"):
|
||||||
|
await tool_fn(action="logs", log_path="/var/log/syslog", tail_lines=0)
|
||||||
|
|
||||||
|
async def test_logs_tail_lines_at_max_accepted(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
_mock_graphql.return_value = {"logFile": {"path": "/var/log/syslog", "content": "ok"}}
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
result = await tool_fn(action="logs", log_path="/var/log/syslog", tail_lines=10_000)
|
||||||
|
assert result["content"] == "ok"
|
||||||
|
|
||||||
|
async def test_non_logs_action_ignores_tail_lines_validation(
|
||||||
|
self, _mock_graphql: AsyncMock
|
||||||
|
) -> None:
|
||||||
|
_mock_graphql.return_value = {"shares": []}
|
||||||
|
tool_fn = _make_tool()
|
||||||
|
result = await tool_fn(action="shares", tail_lines=0)
|
||||||
|
assert result["shares"] == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestFormatKb:
|
||||||
|
def test_none_returns_na(self) -> None:
|
||||||
|
assert format_kb(None) == "N/A"
|
||||||
|
|
||||||
|
def test_invalid_string_returns_na(self) -> None:
|
||||||
|
assert format_kb("not-a-number") == "N/A"
|
||||||
|
|
||||||
|
def test_kilobytes_range(self) -> None:
|
||||||
|
assert format_kb(512) == "512.00 KB"
|
||||||
|
|
||||||
|
def test_megabytes_range(self) -> None:
|
||||||
|
assert format_kb(2048) == "2.00 MB"
|
||||||
|
|
||||||
|
def test_gigabytes_range(self) -> None:
|
||||||
|
assert format_kb(1_048_576) == "1.00 GB"
|
||||||
|
|
||||||
|
def test_terabytes_range(self) -> None:
|
||||||
|
assert format_kb(1_073_741_824) == "1.00 TB"
|
||||||
|
|
||||||
|
def test_boundary_exactly_1024_kb(self) -> None:
|
||||||
|
# 1024 KB = 1 MB
|
||||||
|
assert format_kb(1024) == "1.00 MB"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSafeGet:
|
||||||
|
def test_simple_key_access(self) -> None:
|
||||||
|
assert safe_get({"a": 1}, "a") == 1
|
||||||
|
|
||||||
|
def test_nested_key_access(self) -> None:
|
||||||
|
assert safe_get({"a": {"b": "val"}}, "a", "b") == "val"
|
||||||
|
|
||||||
|
def test_missing_key_returns_none(self) -> None:
|
||||||
|
assert safe_get({"a": 1}, "missing") is None
|
||||||
|
|
||||||
|
def test_none_intermediate_returns_default(self) -> None:
|
||||||
|
assert safe_get({"a": None}, "a", "b") is None
|
||||||
|
|
||||||
|
def test_custom_default_returned(self) -> None:
|
||||||
|
assert safe_get({}, "x", default="fallback") == "fallback"
|
||||||
|
|
||||||
|
def test_non_dict_intermediate_returns_default(self) -> None:
|
||||||
|
assert safe_get({"a": "string"}, "a", "b") is None
|
||||||
|
|
||||||
|
def test_empty_list_default(self) -> None:
|
||||||
|
result = safe_get({}, "missing", default=[])
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
def test_zero_value_not_replaced_by_default(self) -> None:
|
||||||
|
assert safe_get({"temp": 0}, "temp", default="N/A") == 0
|
||||||
|
|
||||||
|
def test_false_value_not_replaced_by_default(self) -> None:
|
||||||
|
assert safe_get({"active": False}, "active", default=True) is False
|
||||||
|
|
||||||
|
def test_empty_string_not_replaced_by_default(self) -> None:
|
||||||
|
assert safe_get({"name": ""}, "name", default="unknown") == ""
|
||||||
|
|
||||||
|
|
||||||
class TestStorageActions:
|
class TestStorageActions:
|
||||||
async def test_shares(self, _mock_graphql: AsyncMock) -> None:
|
async def test_shares(self, _mock_graphql: AsyncMock) -> None:
|
||||||
|
|||||||
156
tests/test_subscription_manager.py
Normal file
156
tests/test_subscription_manager.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
"""Tests for _cap_log_content in subscriptions/manager.py.
|
||||||
|
|
||||||
|
_cap_log_content is a pure utility that prevents unbounded memory growth from
|
||||||
|
log subscription data. It must: return a NEW dict (not mutate), recursively
|
||||||
|
cap nested 'content' fields, and only truncate when both byte limit and line
|
||||||
|
limit are exceeded.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from unraid_mcp.subscriptions.manager import _cap_log_content
|
||||||
|
|
||||||
|
|
||||||
|
class TestCapLogContentImmutability:
|
||||||
|
"""The function must return a new dict — never mutate the input."""
|
||||||
|
|
||||||
|
def test_returns_new_dict(self) -> None:
|
||||||
|
data = {"key": "value"}
|
||||||
|
result = _cap_log_content(data)
|
||||||
|
assert result is not data
|
||||||
|
|
||||||
|
def test_input_not_mutated_on_passthrough(self) -> None:
|
||||||
|
data = {"content": "short text", "other": "value"}
|
||||||
|
original_content = data["content"]
|
||||||
|
_cap_log_content(data)
|
||||||
|
assert data["content"] == original_content
|
||||||
|
|
||||||
|
def test_input_not_mutated_on_truncation(self) -> None:
|
||||||
|
# Use small limits so the truncation path is exercised
|
||||||
|
large_content = "\n".join(f"line {i}" for i in range(200))
|
||||||
|
data = {"content": large_content}
|
||||||
|
with (
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10),
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50),
|
||||||
|
):
|
||||||
|
_cap_log_content(data)
|
||||||
|
# Original data must be unchanged
|
||||||
|
assert data["content"] == large_content
|
||||||
|
|
||||||
|
|
||||||
|
class TestCapLogContentSmallData:
|
||||||
|
"""Content below the byte limit must be returned unchanged."""
|
||||||
|
|
||||||
|
def test_small_content_unchanged(self) -> None:
|
||||||
|
data = {"content": "just a few lines\nof log data\n"}
|
||||||
|
result = _cap_log_content(data)
|
||||||
|
assert result["content"] == data["content"]
|
||||||
|
|
||||||
|
def test_non_content_keys_passed_through(self) -> None:
|
||||||
|
data = {"name": "cpu_subscription", "timestamp": "2026-02-18T00:00:00Z"}
|
||||||
|
result = _cap_log_content(data)
|
||||||
|
assert result == data
|
||||||
|
|
||||||
|
def test_integer_value_passed_through(self) -> None:
|
||||||
|
data = {"count": 42, "active": True}
|
||||||
|
result = _cap_log_content(data)
|
||||||
|
assert result == data
|
||||||
|
|
||||||
|
|
||||||
|
class TestCapLogContentTruncation:
|
||||||
|
"""Content exceeding both byte AND line limits must be truncated to the last N lines."""
|
||||||
|
|
||||||
|
def test_oversized_content_truncated_and_byte_capped(self) -> None:
|
||||||
|
# 200 lines, tiny byte limit: must keep recent content within byte cap.
|
||||||
|
lines = [f"line {i}" for i in range(200)]
|
||||||
|
data = {"content": "\n".join(lines)}
|
||||||
|
with (
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10),
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50),
|
||||||
|
):
|
||||||
|
result = _cap_log_content(data)
|
||||||
|
result_lines = result["content"].splitlines()
|
||||||
|
assert len(result["content"].encode("utf-8", errors="replace")) <= 10
|
||||||
|
# Must keep the most recent line suffix.
|
||||||
|
assert result_lines[-1] == "line 199"
|
||||||
|
|
||||||
|
def test_content_with_fewer_lines_than_limit_still_honors_byte_cap(self) -> None:
|
||||||
|
"""If byte limit is exceeded, output must still be capped even with few lines."""
|
||||||
|
# 30 lines, byte limit 10, line limit 50 -> must cap bytes regardless of line count
|
||||||
|
lines = [f"line {i}" for i in range(30)]
|
||||||
|
data = {"content": "\n".join(lines)}
|
||||||
|
with (
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10),
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50),
|
||||||
|
):
|
||||||
|
result = _cap_log_content(data)
|
||||||
|
assert len(result["content"].encode("utf-8", errors="replace")) <= 10
|
||||||
|
|
||||||
|
def test_non_content_keys_preserved_alongside_truncated_content(self) -> None:
|
||||||
|
lines = [f"line {i}" for i in range(200)]
|
||||||
|
data = {"content": "\n".join(lines), "path": "/var/log/syslog", "total_lines": 200}
|
||||||
|
with (
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10),
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50),
|
||||||
|
):
|
||||||
|
result = _cap_log_content(data)
|
||||||
|
assert result["path"] == "/var/log/syslog"
|
||||||
|
assert result["total_lines"] == 200
|
||||||
|
assert len(result["content"].encode("utf-8", errors="replace")) <= 10
|
||||||
|
|
||||||
|
|
||||||
|
class TestCapLogContentNested:
|
||||||
|
"""Nested 'content' fields inside sub-dicts must also be capped recursively."""
|
||||||
|
|
||||||
|
def test_nested_content_field_capped(self) -> None:
|
||||||
|
lines = [f"line {i}" for i in range(200)]
|
||||||
|
data = {"logFile": {"content": "\n".join(lines), "path": "/var/log/syslog"}}
|
||||||
|
with (
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10),
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50),
|
||||||
|
):
|
||||||
|
result = _cap_log_content(data)
|
||||||
|
assert len(result["logFile"]["content"].encode("utf-8", errors="replace")) <= 10
|
||||||
|
assert result["logFile"]["path"] == "/var/log/syslog"
|
||||||
|
|
||||||
|
def test_deeply_nested_content_capped(self) -> None:
|
||||||
|
lines = [f"line {i}" for i in range(200)]
|
||||||
|
data = {"outer": {"inner": {"content": "\n".join(lines)}}}
|
||||||
|
with (
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10),
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 50),
|
||||||
|
):
|
||||||
|
result = _cap_log_content(data)
|
||||||
|
assert len(result["outer"]["inner"]["content"].encode("utf-8", errors="replace")) <= 10
|
||||||
|
|
||||||
|
def test_nested_non_content_keys_unaffected(self) -> None:
|
||||||
|
data = {"metrics": {"cpu": 42.5, "memory": 8192}}
|
||||||
|
result = _cap_log_content(data)
|
||||||
|
assert result == data
|
||||||
|
|
||||||
|
|
||||||
|
class TestCapLogContentSingleMassiveLine:
|
||||||
|
"""A single line larger than the byte cap must be hard-capped at byte level."""
|
||||||
|
|
||||||
|
def test_single_massive_line_hard_caps_bytes(self) -> None:
|
||||||
|
# One line, no newlines, larger than the byte cap.
|
||||||
|
# The while-loop can't reduce it (len(lines) == 1), so the
|
||||||
|
# last-resort byte-slice path at manager.py:65-69 must fire.
|
||||||
|
huge_content = "x" * 200
|
||||||
|
data = {"content": huge_content}
|
||||||
|
with (
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10),
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 5_000),
|
||||||
|
):
|
||||||
|
result = _cap_log_content(data)
|
||||||
|
assert len(result["content"].encode("utf-8", errors="replace")) <= 10
|
||||||
|
|
||||||
|
def test_single_massive_line_input_not_mutated(self) -> None:
|
||||||
|
huge_content = "x" * 200
|
||||||
|
data = {"content": huge_content}
|
||||||
|
with (
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_BYTES", 10),
|
||||||
|
patch("unraid_mcp.subscriptions.manager._MAX_RESOURCE_DATA_LINES", 5_000),
|
||||||
|
):
|
||||||
|
_cap_log_content(data)
|
||||||
|
assert data["content"] == huge_content
|
||||||
131
tests/test_subscription_validation.py
Normal file
131
tests/test_subscription_validation.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
"""Tests for _validate_subscription_query in diagnostics.py.
|
||||||
|
|
||||||
|
Security-critical: this function is the only guard against arbitrary GraphQL
|
||||||
|
operations (mutations, queries) being sent over the WebSocket subscription channel.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from unraid_mcp.core.exceptions import ToolError
|
||||||
|
from unraid_mcp.subscriptions.diagnostics import (
|
||||||
|
_ALLOWED_SUBSCRIPTION_NAMES,
|
||||||
|
_validate_subscription_query,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateSubscriptionQueryAllowed:
|
||||||
|
"""All whitelisted subscription names must be accepted."""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("sub_name", sorted(_ALLOWED_SUBSCRIPTION_NAMES))
|
||||||
|
def test_all_allowed_names_accepted(self, sub_name: str) -> None:
|
||||||
|
query = f"subscription {{ {sub_name} {{ data }} }}"
|
||||||
|
result = _validate_subscription_query(query)
|
||||||
|
assert result == sub_name
|
||||||
|
|
||||||
|
def test_returns_extracted_subscription_name(self) -> None:
|
||||||
|
query = "subscription { cpuSubscription { usage } }"
|
||||||
|
assert _validate_subscription_query(query) == "cpuSubscription"
|
||||||
|
|
||||||
|
def test_leading_whitespace_accepted(self) -> None:
|
||||||
|
query = " subscription { memorySubscription { free } }"
|
||||||
|
assert _validate_subscription_query(query) == "memorySubscription"
|
||||||
|
|
||||||
|
def test_multiline_query_accepted(self) -> None:
|
||||||
|
query = "subscription {\n logFileSubscription {\n content\n }\n}"
|
||||||
|
assert _validate_subscription_query(query) == "logFileSubscription"
|
||||||
|
|
||||||
|
def test_case_insensitive_subscription_keyword(self) -> None:
|
||||||
|
"""'SUBSCRIPTION' should be accepted (regex uses IGNORECASE)."""
|
||||||
|
query = "SUBSCRIPTION { cpuSubscription { usage } }"
|
||||||
|
assert _validate_subscription_query(query) == "cpuSubscription"
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateSubscriptionQueryForbiddenKeywords:
|
||||||
|
"""Queries containing 'mutation' or 'query' as standalone keywords must be rejected."""
|
||||||
|
|
||||||
|
def test_mutation_keyword_rejected(self) -> None:
|
||||||
|
query = 'mutation { docker { start(id: "abc") } }'
|
||||||
|
with pytest.raises(ToolError, match="must be a subscription"):
|
||||||
|
_validate_subscription_query(query)
|
||||||
|
|
||||||
|
def test_query_keyword_rejected(self) -> None:
|
||||||
|
query = "query { info { os { platform } } }"
|
||||||
|
with pytest.raises(ToolError, match="must be a subscription"):
|
||||||
|
_validate_subscription_query(query)
|
||||||
|
|
||||||
|
def test_mutation_embedded_in_subscription_rejected(self) -> None:
|
||||||
|
"""'mutation' anywhere in the string triggers rejection."""
|
||||||
|
query = "subscription { cpuSubscription { mutation data } }"
|
||||||
|
with pytest.raises(ToolError, match="must be a subscription"):
|
||||||
|
_validate_subscription_query(query)
|
||||||
|
|
||||||
|
def test_query_embedded_in_subscription_rejected(self) -> None:
|
||||||
|
query = "subscription { cpuSubscription { query data } }"
|
||||||
|
with pytest.raises(ToolError, match="must be a subscription"):
|
||||||
|
_validate_subscription_query(query)
|
||||||
|
|
||||||
|
def test_mutation_case_insensitive_rejection(self) -> None:
|
||||||
|
query = 'MUTATION { docker { start(id: "abc") } }'
|
||||||
|
with pytest.raises(ToolError, match="must be a subscription"):
|
||||||
|
_validate_subscription_query(query)
|
||||||
|
|
||||||
|
def test_mutation_field_identifier_not_rejected(self) -> None:
|
||||||
|
"""'mutationField' as an identifier must NOT be rejected — only standalone 'mutation'."""
|
||||||
|
# This tests the \b word boundary in _FORBIDDEN_KEYWORDS
|
||||||
|
query = "subscription { cpuSubscription { mutationField } }"
|
||||||
|
# Should not raise — "mutationField" is an identifier, not the keyword
|
||||||
|
result = _validate_subscription_query(query)
|
||||||
|
assert result == "cpuSubscription"
|
||||||
|
|
||||||
|
def test_query_field_identifier_not_rejected(self) -> None:
|
||||||
|
"""'queryResult' as an identifier must NOT be rejected."""
|
||||||
|
query = "subscription { cpuSubscription { queryResult } }"
|
||||||
|
result = _validate_subscription_query(query)
|
||||||
|
assert result == "cpuSubscription"
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateSubscriptionQueryInvalidFormat:
|
||||||
|
"""Queries that don't match the expected subscription format must be rejected."""
|
||||||
|
|
||||||
|
def test_empty_string_rejected(self) -> None:
|
||||||
|
with pytest.raises(ToolError, match="must start with 'subscription'"):
|
||||||
|
_validate_subscription_query("")
|
||||||
|
|
||||||
|
def test_plain_identifier_rejected(self) -> None:
|
||||||
|
with pytest.raises(ToolError, match="must start with 'subscription'"):
|
||||||
|
_validate_subscription_query("cpuSubscription { usage }")
|
||||||
|
|
||||||
|
def test_missing_operation_body_rejected(self) -> None:
|
||||||
|
with pytest.raises(ToolError, match="must start with 'subscription'"):
|
||||||
|
_validate_subscription_query("subscription")
|
||||||
|
|
||||||
|
def test_subscription_without_field_rejected(self) -> None:
|
||||||
|
"""subscription { } with no field name doesn't match the pattern."""
|
||||||
|
with pytest.raises(ToolError, match="must start with 'subscription'"):
|
||||||
|
_validate_subscription_query("subscription { }")
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateSubscriptionQueryUnknownName:
|
||||||
|
"""Subscription names not in the whitelist must be rejected even if format is valid."""
|
||||||
|
|
||||||
|
def test_unknown_subscription_name_rejected(self) -> None:
|
||||||
|
query = "subscription { unknownSubscription { data } }"
|
||||||
|
with pytest.raises(ToolError, match="not allowed"):
|
||||||
|
_validate_subscription_query(query)
|
||||||
|
|
||||||
|
def test_error_message_includes_allowed_list(self) -> None:
|
||||||
|
"""Error message must list the allowed subscription names for usability."""
|
||||||
|
query = "subscription { badSub { data } }"
|
||||||
|
with pytest.raises(ToolError, match="Allowed subscriptions"):
|
||||||
|
_validate_subscription_query(query)
|
||||||
|
|
||||||
|
def test_arbitrary_field_name_rejected(self) -> None:
|
||||||
|
query = "subscription { users { id email } }"
|
||||||
|
with pytest.raises(ToolError, match="not allowed"):
|
||||||
|
_validate_subscription_query(query)
|
||||||
|
|
||||||
|
def test_close_but_not_whitelisted_rejected(self) -> None:
|
||||||
|
"""'cpu' without 'Subscription' suffix is not in the allow-list."""
|
||||||
|
query = "subscription { cpu { usage } }"
|
||||||
|
with pytest.raises(ToolError, match="not allowed"):
|
||||||
|
_validate_subscription_query(query)
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
"""Unraid MCP Server Package.
|
"""Unraid MCP Server Package."""
|
||||||
|
|
||||||
A modular MCP (Model Context Protocol) server that provides tools to interact
|
from .version import VERSION
|
||||||
with an Unraid server's GraphQL API.
|
|
||||||
"""
|
|
||||||
|
|
||||||
__version__ = "0.2.0"
|
|
||||||
|
__version__ = VERSION
|
||||||
|
|||||||
@@ -5,16 +5,10 @@ that cap at 10MB and start over (no rotation) for consistent use across all modu
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytz
|
|
||||||
from rich.align import Align
|
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.logging import RichHandler
|
from rich.logging import RichHandler
|
||||||
from rich.panel import Panel
|
|
||||||
from rich.rule import Rule
|
|
||||||
from rich.text import Text
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -28,7 +22,7 @@ from .settings import LOG_FILE_PATH, LOG_LEVEL_STR
|
|||||||
|
|
||||||
|
|
||||||
# Global Rich console for consistent formatting
|
# Global Rich console for consistent formatting
|
||||||
console = Console(stderr=True, force_terminal=True)
|
console = Console(stderr=True)
|
||||||
|
|
||||||
|
|
||||||
class OverwriteFileHandler(logging.FileHandler):
|
class OverwriteFileHandler(logging.FileHandler):
|
||||||
@@ -45,12 +39,18 @@ class OverwriteFileHandler(logging.FileHandler):
|
|||||||
delay: Whether to delay file opening
|
delay: Whether to delay file opening
|
||||||
"""
|
"""
|
||||||
self.max_bytes = max_bytes
|
self.max_bytes = max_bytes
|
||||||
|
self._emit_count = 0
|
||||||
|
self._check_interval = 100
|
||||||
super().__init__(filename, mode, encoding, delay)
|
super().__init__(filename, mode, encoding, delay)
|
||||||
|
|
||||||
def emit(self, record):
|
def emit(self, record):
|
||||||
"""Emit a record, checking file size and overwriting if needed."""
|
"""Emit a record, checking file size periodically and overwriting if needed."""
|
||||||
# Check file size before writing
|
self._emit_count += 1
|
||||||
if self.stream and hasattr(self.stream, "name"):
|
if (
|
||||||
|
(self._emit_count == 1 or self._emit_count % self._check_interval == 0)
|
||||||
|
and self.stream
|
||||||
|
and hasattr(self.stream, "name")
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
base_path = Path(self.baseFilename)
|
base_path = Path(self.baseFilename)
|
||||||
if base_path.exists():
|
if base_path.exists():
|
||||||
@@ -91,6 +91,28 @@ class OverwriteFileHandler(logging.FileHandler):
|
|||||||
super().emit(record)
|
super().emit(record)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_shared_file_handler() -> OverwriteFileHandler:
|
||||||
|
"""Create the single shared file handler for all loggers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured OverwriteFileHandler instance
|
||||||
|
"""
|
||||||
|
numeric_log_level = getattr(logging, LOG_LEVEL_STR, logging.INFO)
|
||||||
|
handler = OverwriteFileHandler(LOG_FILE_PATH, max_bytes=10 * 1024 * 1024, encoding="utf-8")
|
||||||
|
handler.setLevel(numeric_log_level)
|
||||||
|
handler.setFormatter(
|
||||||
|
logging.Formatter(
|
||||||
|
"%(asctime)s - %(name)s - %(levelname)s - %(module)s - %(funcName)s - %(lineno)d - %(message)s"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return handler
|
||||||
|
|
||||||
|
|
||||||
|
# Single shared file handler — all loggers reuse this instance to avoid
|
||||||
|
# race conditions from multiple OverwriteFileHandler instances on the same file.
|
||||||
|
_shared_file_handler = _create_shared_file_handler()
|
||||||
|
|
||||||
|
|
||||||
def setup_logger(name: str = "UnraidMCPServer") -> logging.Logger:
|
def setup_logger(name: str = "UnraidMCPServer") -> logging.Logger:
|
||||||
"""Set up and configure the logger with console and file handlers.
|
"""Set up and configure the logger with console and file handlers.
|
||||||
|
|
||||||
@@ -118,19 +140,13 @@ def setup_logger(name: str = "UnraidMCPServer") -> logging.Logger:
|
|||||||
show_level=True,
|
show_level=True,
|
||||||
show_path=False,
|
show_path=False,
|
||||||
rich_tracebacks=True,
|
rich_tracebacks=True,
|
||||||
tracebacks_show_locals=True,
|
tracebacks_show_locals=False,
|
||||||
)
|
)
|
||||||
console_handler.setLevel(numeric_log_level)
|
console_handler.setLevel(numeric_log_level)
|
||||||
logger.addHandler(console_handler)
|
logger.addHandler(console_handler)
|
||||||
|
|
||||||
# File Handler with 10MB cap (overwrites instead of rotating)
|
# Reuse the shared file handler
|
||||||
file_handler = OverwriteFileHandler(LOG_FILE_PATH, max_bytes=10 * 1024 * 1024, encoding="utf-8")
|
logger.addHandler(_shared_file_handler)
|
||||||
file_handler.setLevel(numeric_log_level)
|
|
||||||
file_formatter = logging.Formatter(
|
|
||||||
"%(asctime)s - %(name)s - %(levelname)s - %(module)s - %(funcName)s - %(lineno)d - %(message)s"
|
|
||||||
)
|
|
||||||
file_handler.setFormatter(file_formatter)
|
|
||||||
logger.addHandler(file_handler)
|
|
||||||
|
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
@@ -157,20 +173,14 @@ def configure_fastmcp_logger_with_rich() -> logging.Logger | None:
|
|||||||
show_level=True,
|
show_level=True,
|
||||||
show_path=False,
|
show_path=False,
|
||||||
rich_tracebacks=True,
|
rich_tracebacks=True,
|
||||||
tracebacks_show_locals=True,
|
tracebacks_show_locals=False,
|
||||||
markup=True,
|
markup=True,
|
||||||
)
|
)
|
||||||
console_handler.setLevel(numeric_log_level)
|
console_handler.setLevel(numeric_log_level)
|
||||||
fastmcp_logger.addHandler(console_handler)
|
fastmcp_logger.addHandler(console_handler)
|
||||||
|
|
||||||
# File Handler with 10MB cap (overwrites instead of rotating)
|
# Reuse the shared file handler
|
||||||
file_handler = OverwriteFileHandler(LOG_FILE_PATH, max_bytes=10 * 1024 * 1024, encoding="utf-8")
|
fastmcp_logger.addHandler(_shared_file_handler)
|
||||||
file_handler.setLevel(numeric_log_level)
|
|
||||||
file_formatter = logging.Formatter(
|
|
||||||
"%(asctime)s - %(name)s - %(levelname)s - %(module)s - %(funcName)s - %(lineno)d - %(message)s"
|
|
||||||
)
|
|
||||||
file_handler.setFormatter(file_formatter)
|
|
||||||
fastmcp_logger.addHandler(file_handler)
|
|
||||||
|
|
||||||
fastmcp_logger.setLevel(numeric_log_level)
|
fastmcp_logger.setLevel(numeric_log_level)
|
||||||
|
|
||||||
@@ -186,30 +196,19 @@ def configure_fastmcp_logger_with_rich() -> logging.Logger | None:
|
|||||||
show_level=True,
|
show_level=True,
|
||||||
show_path=False,
|
show_path=False,
|
||||||
rich_tracebacks=True,
|
rich_tracebacks=True,
|
||||||
tracebacks_show_locals=True,
|
tracebacks_show_locals=False,
|
||||||
markup=True,
|
markup=True,
|
||||||
)
|
)
|
||||||
root_console_handler.setLevel(numeric_log_level)
|
root_console_handler.setLevel(numeric_log_level)
|
||||||
root_logger.addHandler(root_console_handler)
|
root_logger.addHandler(root_console_handler)
|
||||||
|
|
||||||
# File Handler for root logger with 10MB cap (overwrites instead of rotating)
|
# Reuse the shared file handler for root logger
|
||||||
root_file_handler = OverwriteFileHandler(
|
root_logger.addHandler(_shared_file_handler)
|
||||||
LOG_FILE_PATH, max_bytes=10 * 1024 * 1024, encoding="utf-8"
|
|
||||||
)
|
|
||||||
root_file_handler.setLevel(numeric_log_level)
|
|
||||||
root_file_handler.setFormatter(file_formatter)
|
|
||||||
root_logger.addHandler(root_file_handler)
|
|
||||||
root_logger.setLevel(numeric_log_level)
|
root_logger.setLevel(numeric_log_level)
|
||||||
|
|
||||||
return fastmcp_logger
|
return fastmcp_logger
|
||||||
|
|
||||||
|
|
||||||
def setup_uvicorn_logging() -> logging.Logger | None:
|
|
||||||
"""Configure uvicorn and other third-party loggers to use Rich formatting."""
|
|
||||||
# This function is kept for backward compatibility but now delegates to FastMCP
|
|
||||||
return configure_fastmcp_logger_with_rich()
|
|
||||||
|
|
||||||
|
|
||||||
def log_configuration_status(logger: logging.Logger) -> None:
|
def log_configuration_status(logger: logging.Logger) -> None:
|
||||||
"""Log configuration status at startup.
|
"""Log configuration status at startup.
|
||||||
|
|
||||||
@@ -242,97 +241,6 @@ def log_configuration_status(logger: logging.Logger) -> None:
|
|||||||
logger.error(f"Missing required configuration: {config['missing_config']}")
|
logger.error(f"Missing required configuration: {config['missing_config']}")
|
||||||
|
|
||||||
|
|
||||||
# Development logging helpers for Rich formatting
|
|
||||||
def get_est_timestamp() -> str:
|
|
||||||
"""Get current timestamp in EST timezone with YY/MM/DD format."""
|
|
||||||
est = pytz.timezone("US/Eastern")
|
|
||||||
now = datetime.now(est)
|
|
||||||
return now.strftime("%y/%m/%d %H:%M:%S")
|
|
||||||
|
|
||||||
|
|
||||||
def log_header(title: str) -> None:
|
|
||||||
"""Print a beautiful header panel with Nordic blue styling."""
|
|
||||||
panel = Panel(
|
|
||||||
Align.center(Text(title, style="bold white")),
|
|
||||||
style="#5E81AC", # Nordic blue
|
|
||||||
padding=(0, 2),
|
|
||||||
border_style="#81A1C1", # Light Nordic blue
|
|
||||||
)
|
|
||||||
console.print(panel)
|
|
||||||
|
|
||||||
|
|
||||||
def log_with_level_and_indent(message: str, level: str = "info", indent: int = 0) -> None:
|
|
||||||
"""Log a message with specific level and indentation."""
|
|
||||||
timestamp = get_est_timestamp()
|
|
||||||
indent_str = " " * indent
|
|
||||||
|
|
||||||
# Enhanced Nordic color scheme with more blues
|
|
||||||
level_config = {
|
|
||||||
"error": {"color": "#BF616A", "icon": "❌", "style": "bold"}, # Nordic red
|
|
||||||
"warning": {"color": "#EBCB8B", "icon": "⚠️", "style": ""}, # Nordic yellow
|
|
||||||
"success": {"color": "#A3BE8C", "icon": "✅", "style": "bold"}, # Nordic green
|
|
||||||
"info": {"color": "#5E81AC", "icon": "\u2139\ufe0f", "style": "bold"}, # Nordic blue (bold)
|
|
||||||
"status": {"color": "#81A1C1", "icon": "🔍", "style": ""}, # Light Nordic blue
|
|
||||||
"debug": {"color": "#4C566A", "icon": "🐛", "style": ""}, # Nordic dark gray
|
|
||||||
}
|
|
||||||
|
|
||||||
config = level_config.get(
|
|
||||||
level, {"color": "#81A1C1", "icon": "•", "style": ""}
|
|
||||||
) # Default to light Nordic blue
|
|
||||||
|
|
||||||
# Create beautifully formatted text
|
|
||||||
text = Text()
|
|
||||||
|
|
||||||
# Timestamp with Nordic blue styling
|
|
||||||
text.append(f"[{timestamp}]", style="#81A1C1") # Light Nordic blue for timestamps
|
|
||||||
text.append(" ")
|
|
||||||
|
|
||||||
# Indentation with Nordic blue styling
|
|
||||||
if indent > 0:
|
|
||||||
text.append(indent_str, style="#81A1C1")
|
|
||||||
|
|
||||||
# Level icon (only for certain levels)
|
|
||||||
if level in ["error", "warning", "success"]:
|
|
||||||
# Extract emoji from message if it starts with one, to avoid duplication
|
|
||||||
if message and len(message) > 0 and ord(message[0]) >= 0x1F600: # Emoji range
|
|
||||||
# Message already has emoji, don't add icon
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
text.append(f"{config['icon']} ", style=config["color"])
|
|
||||||
|
|
||||||
# Message content
|
|
||||||
message_style = f"{config['color']} {config['style']}".strip()
|
|
||||||
text.append(message, style=message_style)
|
|
||||||
|
|
||||||
console.print(text)
|
|
||||||
|
|
||||||
|
|
||||||
def log_separator() -> None:
|
|
||||||
"""Print a beautiful separator line with Nordic blue styling."""
|
|
||||||
console.print(Rule(style="#81A1C1"))
|
|
||||||
|
|
||||||
|
|
||||||
# Convenience functions for different log levels
|
|
||||||
def log_error(message: str, indent: int = 0) -> None:
|
|
||||||
log_with_level_and_indent(message, "error", indent)
|
|
||||||
|
|
||||||
|
|
||||||
def log_warning(message: str, indent: int = 0) -> None:
|
|
||||||
log_with_level_and_indent(message, "warning", indent)
|
|
||||||
|
|
||||||
|
|
||||||
def log_success(message: str, indent: int = 0) -> None:
|
|
||||||
log_with_level_and_indent(message, "success", indent)
|
|
||||||
|
|
||||||
|
|
||||||
def log_info(message: str, indent: int = 0) -> None:
|
|
||||||
log_with_level_and_indent(message, "info", indent)
|
|
||||||
|
|
||||||
|
|
||||||
def log_status(message: str, indent: int = 0) -> None:
|
|
||||||
log_with_level_and_indent(message, "status", indent)
|
|
||||||
|
|
||||||
|
|
||||||
# Global logger instance - modules can import this directly
|
# Global logger instance - modules can import this directly
|
||||||
if FASTMCP_AVAILABLE:
|
if FASTMCP_AVAILABLE:
|
||||||
# Use FastMCP logger with Rich formatting
|
# Use FastMCP logger with Rich formatting
|
||||||
@@ -341,5 +249,3 @@ if FASTMCP_AVAILABLE:
|
|||||||
else:
|
else:
|
||||||
# Fallback to our custom logger if FastMCP is not available
|
# Fallback to our custom logger if FastMCP is not available
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
# Setup uvicorn logging when module is imported
|
|
||||||
setup_uvicorn_logging()
|
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ from typing import Any
|
|||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from ..version import VERSION as APP_VERSION
|
||||||
|
|
||||||
|
|
||||||
# Get the script directory (config module location)
|
# Get the script directory (config module location)
|
||||||
SCRIPT_DIR = Path(__file__).parent # /home/user/code/unraid-mcp/unraid_mcp/config/
|
SCRIPT_DIR = Path(__file__).parent # /home/user/code/unraid-mcp/unraid_mcp/config/
|
||||||
@@ -30,16 +32,13 @@ for dotenv_path in dotenv_paths:
|
|||||||
load_dotenv(dotenv_path=dotenv_path)
|
load_dotenv(dotenv_path=dotenv_path)
|
||||||
break
|
break
|
||||||
|
|
||||||
# Application Version
|
|
||||||
VERSION = "0.2.0"
|
|
||||||
|
|
||||||
# Core API Configuration
|
# Core API Configuration
|
||||||
UNRAID_API_URL = os.getenv("UNRAID_API_URL")
|
UNRAID_API_URL = os.getenv("UNRAID_API_URL")
|
||||||
UNRAID_API_KEY = os.getenv("UNRAID_API_KEY")
|
UNRAID_API_KEY = os.getenv("UNRAID_API_KEY")
|
||||||
|
|
||||||
# Server Configuration
|
# Server Configuration
|
||||||
UNRAID_MCP_PORT = int(os.getenv("UNRAID_MCP_PORT", "6970"))
|
UNRAID_MCP_PORT = int(os.getenv("UNRAID_MCP_PORT", "6970"))
|
||||||
UNRAID_MCP_HOST = os.getenv("UNRAID_MCP_HOST", "0.0.0.0")
|
UNRAID_MCP_HOST = os.getenv("UNRAID_MCP_HOST", "0.0.0.0") # noqa: S104 — intentional for Docker
|
||||||
UNRAID_MCP_TRANSPORT = os.getenv("UNRAID_MCP_TRANSPORT", "streamable-http").lower()
|
UNRAID_MCP_TRANSPORT = os.getenv("UNRAID_MCP_TRANSPORT", "streamable-http").lower()
|
||||||
|
|
||||||
# SSL Configuration
|
# SSL Configuration
|
||||||
@@ -54,11 +53,18 @@ else: # Path to CA bundle
|
|||||||
# Logging Configuration
|
# Logging Configuration
|
||||||
LOG_LEVEL_STR = os.getenv("UNRAID_MCP_LOG_LEVEL", "INFO").upper()
|
LOG_LEVEL_STR = os.getenv("UNRAID_MCP_LOG_LEVEL", "INFO").upper()
|
||||||
LOG_FILE_NAME = os.getenv("UNRAID_MCP_LOG_FILE", "unraid-mcp.log")
|
LOG_FILE_NAME = os.getenv("UNRAID_MCP_LOG_FILE", "unraid-mcp.log")
|
||||||
LOGS_DIR = Path("/tmp")
|
# Use /.dockerenv as the container indicator for robust Docker detection.
|
||||||
|
IS_DOCKER = Path("/.dockerenv").exists()
|
||||||
|
LOGS_DIR = Path("/app/logs") if IS_DOCKER else PROJECT_ROOT / "logs"
|
||||||
LOG_FILE_PATH = LOGS_DIR / LOG_FILE_NAME
|
LOG_FILE_PATH = LOGS_DIR / LOG_FILE_NAME
|
||||||
|
|
||||||
# Ensure logs directory exists
|
# Ensure logs directory exists; if creation fails, fall back to /tmp.
|
||||||
LOGS_DIR.mkdir(parents=True, exist_ok=True)
|
try:
|
||||||
|
LOGS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
except OSError:
|
||||||
|
LOGS_DIR = PROJECT_ROOT / ".cache" / "logs"
|
||||||
|
LOGS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
LOG_FILE_PATH = LOGS_DIR / LOG_FILE_NAME
|
||||||
|
|
||||||
# HTTP Client Configuration
|
# HTTP Client Configuration
|
||||||
TIMEOUT_CONFIG = {
|
TIMEOUT_CONFIG = {
|
||||||
@@ -104,3 +110,5 @@ def get_config_summary() -> dict[str, Any]:
|
|||||||
"config_valid": is_valid,
|
"config_valid": is_valid,
|
||||||
"missing_config": missing if not is_valid else None,
|
"missing_config": missing if not is_valid else None,
|
||||||
}
|
}
|
||||||
|
# Re-export application version from a single source of truth.
|
||||||
|
VERSION = APP_VERSION
|
||||||
|
|||||||
@@ -5,8 +5,11 @@ to the Unraid API with proper timeout handling and error management.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import hashlib
|
||||||
import json
|
import json
|
||||||
from typing import Any
|
import re
|
||||||
|
import time
|
||||||
|
from typing import Any, Final
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
@@ -21,8 +24,22 @@ from ..config.settings import (
|
|||||||
from ..core.exceptions import ToolError
|
from ..core.exceptions import ToolError
|
||||||
|
|
||||||
|
|
||||||
# Sensitive keys to redact from debug logs
|
# Sensitive keys to redact from debug logs (frozenset — immutable, Final — no accidental reassignment)
|
||||||
_SENSITIVE_KEYS = {"password", "key", "secret", "token", "apikey"}
|
_SENSITIVE_KEYS: Final[frozenset[str]] = frozenset(
|
||||||
|
{
|
||||||
|
"password",
|
||||||
|
"key",
|
||||||
|
"secret",
|
||||||
|
"token",
|
||||||
|
"apikey",
|
||||||
|
"authorization",
|
||||||
|
"cookie",
|
||||||
|
"session",
|
||||||
|
"credential",
|
||||||
|
"passphrase",
|
||||||
|
"jwt",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _is_sensitive_key(key: str) -> bool:
|
def _is_sensitive_key(key: str) -> bool:
|
||||||
@@ -31,14 +48,14 @@ def _is_sensitive_key(key: str) -> bool:
|
|||||||
return any(s in key_lower for s in _SENSITIVE_KEYS)
|
return any(s in key_lower for s in _SENSITIVE_KEYS)
|
||||||
|
|
||||||
|
|
||||||
def _redact_sensitive(obj: Any) -> Any:
|
def redact_sensitive(obj: Any) -> Any:
|
||||||
"""Recursively redact sensitive values from nested dicts/lists."""
|
"""Recursively redact sensitive values from nested dicts/lists."""
|
||||||
if isinstance(obj, dict):
|
if isinstance(obj, dict):
|
||||||
return {
|
return {
|
||||||
k: ("***" if _is_sensitive_key(k) else _redact_sensitive(v)) for k, v in obj.items()
|
k: ("***" if _is_sensitive_key(k) else redact_sensitive(v)) for k, v in obj.items()
|
||||||
}
|
}
|
||||||
if isinstance(obj, list):
|
if isinstance(obj, list):
|
||||||
return [_redact_sensitive(item) for item in obj]
|
return [redact_sensitive(item) for item in obj]
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
@@ -66,8 +83,116 @@ def get_timeout_for_operation(profile: str) -> httpx.Timeout:
|
|||||||
|
|
||||||
|
|
||||||
# Global connection pool (module-level singleton)
|
# Global connection pool (module-level singleton)
|
||||||
|
# Python 3.12+ asyncio.Lock() is safe at module level — no running event loop required
|
||||||
_http_client: httpx.AsyncClient | None = None
|
_http_client: httpx.AsyncClient | None = None
|
||||||
_client_lock = asyncio.Lock()
|
_client_lock: Final[asyncio.Lock] = asyncio.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
class _RateLimiter:
|
||||||
|
"""Token bucket rate limiter for Unraid API (100 req / 10s hard limit).
|
||||||
|
|
||||||
|
Uses 90 tokens with 9.0 tokens/sec refill for 10% safety headroom.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, max_tokens: int = 90, refill_rate: float = 9.0) -> None:
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
self.tokens = float(max_tokens)
|
||||||
|
self.refill_rate = refill_rate # tokens per second
|
||||||
|
self.last_refill = time.monotonic()
|
||||||
|
# asyncio.Lock() is safe to create at __init__ time (Python 3.12+)
|
||||||
|
self._lock: Final[asyncio.Lock] = asyncio.Lock()
|
||||||
|
|
||||||
|
def _refill(self) -> None:
|
||||||
|
"""Refill tokens based on elapsed time."""
|
||||||
|
now = time.monotonic()
|
||||||
|
elapsed = now - self.last_refill
|
||||||
|
self.tokens = min(self.max_tokens, self.tokens + elapsed * self.refill_rate)
|
||||||
|
self.last_refill = now
|
||||||
|
|
||||||
|
async def acquire(self) -> None:
|
||||||
|
"""Consume one token, waiting if necessary for refill."""
|
||||||
|
while True:
|
||||||
|
async with self._lock:
|
||||||
|
self._refill()
|
||||||
|
if self.tokens >= 1:
|
||||||
|
self.tokens -= 1
|
||||||
|
return
|
||||||
|
wait_time = (1 - self.tokens) / self.refill_rate
|
||||||
|
|
||||||
|
# Sleep outside the lock so other coroutines aren't blocked
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
|
||||||
|
|
||||||
|
_rate_limiter = _RateLimiter()
|
||||||
|
|
||||||
|
|
||||||
|
# --- TTL Cache for stable read-only queries ---
|
||||||
|
|
||||||
|
# Queries whose results change infrequently and are safe to cache.
|
||||||
|
# Mutations and volatile queries (metrics, docker, array state) are excluded.
|
||||||
|
_CACHEABLE_QUERY_PREFIXES = frozenset(
|
||||||
|
{
|
||||||
|
"GetNetworkConfig",
|
||||||
|
"GetRegistrationInfo",
|
||||||
|
"GetOwner",
|
||||||
|
"GetFlash",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
_CACHE_TTL_SECONDS = 60.0
|
||||||
|
_OPERATION_NAME_PATTERN = re.compile(r"^(?:query\s+)?([_A-Za-z][_0-9A-Za-z]*)\b")
|
||||||
|
|
||||||
|
|
||||||
|
class _QueryCache:
|
||||||
|
"""Simple TTL cache for GraphQL query responses.
|
||||||
|
|
||||||
|
Keyed by a hash of (query, variables). Entries expire after _CACHE_TTL_SECONDS.
|
||||||
|
Only caches responses for queries whose operation name is in _CACHEABLE_QUERY_PREFIXES.
|
||||||
|
Mutation requests always bypass the cache.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._store: dict[str, tuple[float, dict[str, Any]]] = {}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _cache_key(query: str, variables: dict[str, Any] | None) -> str:
|
||||||
|
raw = query + json.dumps(variables or {}, sort_keys=True)
|
||||||
|
return hashlib.sha256(raw.encode()).hexdigest()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_cacheable(query: str) -> bool:
|
||||||
|
"""Check if a query is eligible for caching based on its operation name."""
|
||||||
|
normalized = query.lstrip()
|
||||||
|
if normalized.startswith("mutation"):
|
||||||
|
return False
|
||||||
|
match = _OPERATION_NAME_PATTERN.match(normalized)
|
||||||
|
if not match:
|
||||||
|
return False
|
||||||
|
return match.group(1) in _CACHEABLE_QUERY_PREFIXES
|
||||||
|
|
||||||
|
def get(self, query: str, variables: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||||
|
"""Return cached result if present and not expired, else None."""
|
||||||
|
key = self._cache_key(query, variables)
|
||||||
|
entry = self._store.get(key)
|
||||||
|
if entry is None:
|
||||||
|
return None
|
||||||
|
expires_at, data = entry
|
||||||
|
if time.monotonic() > expires_at:
|
||||||
|
del self._store[key]
|
||||||
|
return None
|
||||||
|
return data
|
||||||
|
|
||||||
|
def put(self, query: str, variables: dict[str, Any] | None, data: dict[str, Any]) -> None:
|
||||||
|
"""Store a query result with TTL expiry."""
|
||||||
|
key = self._cache_key(query, variables)
|
||||||
|
self._store[key] = (time.monotonic() + _CACHE_TTL_SECONDS, data)
|
||||||
|
|
||||||
|
def invalidate_all(self) -> None:
|
||||||
|
"""Clear the entire cache (called after mutations)."""
|
||||||
|
self._store.clear()
|
||||||
|
|
||||||
|
|
||||||
|
_query_cache = _QueryCache()
|
||||||
|
|
||||||
|
|
||||||
def is_idempotent_error(error_message: str, operation: str) -> bool:
|
def is_idempotent_error(error_message: str, operation: str) -> bool:
|
||||||
@@ -109,7 +234,7 @@ async def _create_http_client() -> httpx.AsyncClient:
|
|||||||
return httpx.AsyncClient(
|
return httpx.AsyncClient(
|
||||||
# Connection pool settings
|
# Connection pool settings
|
||||||
limits=httpx.Limits(
|
limits=httpx.Limits(
|
||||||
max_keepalive_connections=20, max_connections=100, keepalive_expiry=30.0
|
max_keepalive_connections=20, max_connections=20, keepalive_expiry=30.0
|
||||||
),
|
),
|
||||||
# Default timeout (can be overridden per-request)
|
# Default timeout (can be overridden per-request)
|
||||||
timeout=DEFAULT_TIMEOUT,
|
timeout=DEFAULT_TIMEOUT,
|
||||||
@@ -123,33 +248,28 @@ async def _create_http_client() -> httpx.AsyncClient:
|
|||||||
async def get_http_client() -> httpx.AsyncClient:
|
async def get_http_client() -> httpx.AsyncClient:
|
||||||
"""Get or create shared HTTP client with connection pooling.
|
"""Get or create shared HTTP client with connection pooling.
|
||||||
|
|
||||||
The client is protected by an asyncio lock to prevent concurrent creation.
|
Uses double-checked locking: fast-path skips the lock when the client
|
||||||
If the existing client was closed (e.g., during shutdown), a new one is created.
|
is already initialized, only acquiring it for initial creation or
|
||||||
|
recovery after close.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Singleton AsyncClient instance with connection pooling enabled
|
Singleton AsyncClient instance with connection pooling enabled
|
||||||
"""
|
"""
|
||||||
global _http_client
|
global _http_client
|
||||||
|
|
||||||
|
# Fast-path: skip lock if client is already initialized and open
|
||||||
|
client = _http_client
|
||||||
|
if client is not None and not client.is_closed:
|
||||||
|
return client
|
||||||
|
|
||||||
|
# Slow-path: acquire lock for initialization
|
||||||
async with _client_lock:
|
async with _client_lock:
|
||||||
if _http_client is None or _http_client.is_closed:
|
if _http_client is None or _http_client.is_closed:
|
||||||
_http_client = await _create_http_client()
|
_http_client = await _create_http_client()
|
||||||
logger.info(
|
logger.info(
|
||||||
"Created shared HTTP client with connection pooling (20 keepalive, 100 max connections)"
|
"Created shared HTTP client with connection pooling (20 keepalive, 20 max connections)"
|
||||||
)
|
)
|
||||||
|
return _http_client
|
||||||
client = _http_client
|
|
||||||
|
|
||||||
# Verify client is still open after releasing the lock.
|
|
||||||
# In asyncio's cooperative model this is unlikely to fail, but guards
|
|
||||||
# against edge cases where close_http_client runs between yield points.
|
|
||||||
if client.is_closed:
|
|
||||||
async with _client_lock:
|
|
||||||
_http_client = await _create_http_client()
|
|
||||||
client = _http_client
|
|
||||||
logger.info("Re-created HTTP client after unexpected close")
|
|
||||||
|
|
||||||
return client
|
|
||||||
|
|
||||||
|
|
||||||
async def close_http_client() -> None:
|
async def close_http_client() -> None:
|
||||||
@@ -190,6 +310,14 @@ async def make_graphql_request(
|
|||||||
if not UNRAID_API_KEY:
|
if not UNRAID_API_KEY:
|
||||||
raise ToolError("UNRAID_API_KEY not configured")
|
raise ToolError("UNRAID_API_KEY not configured")
|
||||||
|
|
||||||
|
# Check TTL cache for stable read-only queries
|
||||||
|
is_mutation = query.lstrip().startswith("mutation")
|
||||||
|
if not is_mutation and _query_cache.is_cacheable(query):
|
||||||
|
cached = _query_cache.get(query, variables)
|
||||||
|
if cached is not None:
|
||||||
|
logger.debug("Returning cached response for query")
|
||||||
|
return cached
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"X-API-Key": UNRAID_API_KEY,
|
"X-API-Key": UNRAID_API_KEY,
|
||||||
@@ -202,19 +330,41 @@ async def make_graphql_request(
|
|||||||
logger.debug(f"Making GraphQL request to {UNRAID_API_URL}:")
|
logger.debug(f"Making GraphQL request to {UNRAID_API_URL}:")
|
||||||
logger.debug(f"Query: {query[:200]}{'...' if len(query) > 200 else ''}") # Log truncated query
|
logger.debug(f"Query: {query[:200]}{'...' if len(query) > 200 else ''}") # Log truncated query
|
||||||
if variables:
|
if variables:
|
||||||
logger.debug(f"Variables: {_redact_sensitive(variables)}")
|
logger.debug(f"Variables: {redact_sensitive(variables)}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Rate limit: consume a token before making the request
|
||||||
|
await _rate_limiter.acquire()
|
||||||
|
|
||||||
# Get the shared HTTP client with connection pooling
|
# Get the shared HTTP client with connection pooling
|
||||||
client = await get_http_client()
|
client = await get_http_client()
|
||||||
|
|
||||||
# Override timeout if custom timeout specified
|
# Retry loop for 429 rate limit responses
|
||||||
|
post_kwargs: dict[str, Any] = {"json": payload, "headers": headers}
|
||||||
if custom_timeout is not None:
|
if custom_timeout is not None:
|
||||||
response = await client.post(
|
post_kwargs["timeout"] = custom_timeout
|
||||||
UNRAID_API_URL, json=payload, headers=headers, timeout=custom_timeout
|
|
||||||
|
response: httpx.Response | None = None
|
||||||
|
for attempt in range(3):
|
||||||
|
response = await client.post(UNRAID_API_URL, **post_kwargs)
|
||||||
|
if response.status_code == 429:
|
||||||
|
backoff = 2**attempt
|
||||||
|
logger.warning(
|
||||||
|
f"Rate limited (429) by Unraid API, retrying in {backoff}s (attempt {attempt + 1}/3)"
|
||||||
|
)
|
||||||
|
await asyncio.sleep(backoff)
|
||||||
|
continue
|
||||||
|
break
|
||||||
|
|
||||||
|
if response is None: # pragma: no cover — guaranteed by loop
|
||||||
|
raise ToolError("No response received after retry attempts")
|
||||||
|
|
||||||
|
# Provide a clear message when all retries are exhausted on 429
|
||||||
|
if response.status_code == 429:
|
||||||
|
logger.error("Rate limit (429) persisted after 3 retries — request aborted")
|
||||||
|
raise ToolError(
|
||||||
|
"Unraid API is rate limiting requests. Wait ~10 seconds before retrying."
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
response = await client.post(UNRAID_API_URL, json=payload, headers=headers)
|
|
||||||
|
|
||||||
response.raise_for_status() # Raise an exception for HTTP error codes 4xx/5xx
|
response.raise_for_status() # Raise an exception for HTTP error codes 4xx/5xx
|
||||||
|
|
||||||
@@ -245,14 +395,27 @@ async def make_graphql_request(
|
|||||||
|
|
||||||
logger.debug("GraphQL request successful.")
|
logger.debug("GraphQL request successful.")
|
||||||
data = response_data.get("data", {})
|
data = response_data.get("data", {})
|
||||||
return data if isinstance(data, dict) else {} # Ensure we return dict
|
result = data if isinstance(data, dict) else {} # Ensure we return dict
|
||||||
|
|
||||||
|
# Invalidate cache on mutations; cache eligible query results
|
||||||
|
if is_mutation:
|
||||||
|
_query_cache.invalidate_all()
|
||||||
|
elif _query_cache.is_cacheable(query):
|
||||||
|
_query_cache.put(query, variables, result)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
|
# Log full details internally; only expose status code to MCP client
|
||||||
logger.error(f"HTTP error occurred: {e.response.status_code} - {e.response.text}")
|
logger.error(f"HTTP error occurred: {e.response.status_code} - {e.response.text}")
|
||||||
raise ToolError(f"HTTP error {e.response.status_code}: {e.response.text}") from e
|
raise ToolError(
|
||||||
|
f"Unraid API returned HTTP {e.response.status_code}. Check server logs for details."
|
||||||
|
) from e
|
||||||
except httpx.RequestError as e:
|
except httpx.RequestError as e:
|
||||||
|
# Log full error internally; give safe summary to MCP client
|
||||||
logger.error(f"Request error occurred: {e}")
|
logger.error(f"Request error occurred: {e}")
|
||||||
raise ToolError(f"Network connection error: {e!s}") from e
|
raise ToolError(f"Network error connecting to Unraid API: {type(e).__name__}") from e
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
|
# Log full decode error; give safe summary to MCP client
|
||||||
logger.error(f"Failed to decode JSON response: {e}")
|
logger.error(f"Failed to decode JSON response: {e}")
|
||||||
raise ToolError(f"Invalid JSON response from Unraid API: {e!s}") from e
|
raise ToolError("Unraid API returned an invalid response (not valid JSON)") from e
|
||||||
|
|||||||
@@ -4,6 +4,10 @@ This module defines custom exception classes for consistent error handling
|
|||||||
throughout the application, with proper integration to FastMCP's error system.
|
throughout the application, with proper integration to FastMCP's error system.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import logging
|
||||||
|
from collections.abc import Iterator
|
||||||
|
|
||||||
from fastmcp.exceptions import ToolError as FastMCPToolError
|
from fastmcp.exceptions import ToolError as FastMCPToolError
|
||||||
|
|
||||||
|
|
||||||
@@ -19,36 +23,34 @@ class ToolError(FastMCPToolError):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ConfigurationError(ToolError):
|
@contextlib.contextmanager
|
||||||
"""Raised when there are configuration-related errors."""
|
def tool_error_handler(
|
||||||
|
tool_name: str,
|
||||||
|
action: str,
|
||||||
|
logger: logging.Logger,
|
||||||
|
) -> Iterator[None]:
|
||||||
|
"""Context manager that standardizes tool error handling.
|
||||||
|
|
||||||
pass
|
Re-raises ToolError as-is. Gives TimeoutError a descriptive message.
|
||||||
|
Catches all other exceptions, logs them with full traceback, and wraps them
|
||||||
|
in ToolError with a descriptive message.
|
||||||
|
|
||||||
|
Args:
|
||||||
class UnraidAPIError(ToolError):
|
tool_name: The tool name for error messages (e.g., "docker", "vm").
|
||||||
"""Raised when the Unraid API returns an error or is unreachable."""
|
action: The current action being executed.
|
||||||
|
logger: The logger instance to use for error logging.
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SubscriptionError(ToolError):
|
|
||||||
"""Raised when there are WebSocket subscription-related errors."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ValidationError(ToolError):
|
|
||||||
"""Raised when input validation fails."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class IdempotentOperationError(ToolError):
|
|
||||||
"""Raised when an operation is idempotent (already in desired state).
|
|
||||||
|
|
||||||
This is used internally to signal that an operation was already complete,
|
|
||||||
which should typically be converted to a success response rather than
|
|
||||||
propagated as an error to the user.
|
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
pass
|
yield
|
||||||
|
except ToolError:
|
||||||
|
raise
|
||||||
|
except TimeoutError as e:
|
||||||
|
logger.exception(f"Timeout in unraid_{tool_name} action={action}: request exceeded time limit")
|
||||||
|
raise ToolError(
|
||||||
|
f"Request timed out executing {tool_name}/{action}. The Unraid API did not respond in time."
|
||||||
|
) from e
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Error in unraid_{tool_name} action={action}")
|
||||||
|
raise ToolError(
|
||||||
|
f"Failed to execute {tool_name}/{action}. Check server logs for details."
|
||||||
|
) from e
|
||||||
|
|||||||
@@ -9,38 +9,21 @@ from datetime import datetime
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass(slots=True)
|
||||||
class SubscriptionData:
|
class SubscriptionData:
|
||||||
"""Container for subscription data with metadata."""
|
"""Container for subscription data with metadata.
|
||||||
|
|
||||||
|
Note: last_updated must be timezone-aware (use datetime.now(UTC)).
|
||||||
|
"""
|
||||||
|
|
||||||
data: dict[str, Any]
|
data: dict[str, Any]
|
||||||
last_updated: datetime
|
last_updated: datetime # Must be timezone-aware (UTC)
|
||||||
subscription_type: str
|
subscription_type: str
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
@dataclass
|
if self.last_updated.tzinfo is None:
|
||||||
class SystemHealth:
|
raise ValueError(
|
||||||
"""Container for system health status information."""
|
"last_updated must be timezone-aware; use datetime.now(UTC)"
|
||||||
|
)
|
||||||
is_healthy: bool
|
if not self.subscription_type.strip():
|
||||||
issues: list[str]
|
raise ValueError("subscription_type must be a non-empty string")
|
||||||
warnings: list[str]
|
|
||||||
last_checked: datetime
|
|
||||||
component_status: dict[str, str]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class APIResponse:
|
|
||||||
"""Container for standardized API response data."""
|
|
||||||
|
|
||||||
success: bool
|
|
||||||
data: dict[str, Any] | None = None
|
|
||||||
error: str | None = None
|
|
||||||
metadata: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
# Type aliases for common data structures
|
|
||||||
ConfigValue = str | int | bool | float | None
|
|
||||||
ConfigDict = dict[str, ConfigValue]
|
|
||||||
GraphQLVariables = dict[str, Any]
|
|
||||||
HealthStatus = dict[str, str | bool | int | list[Any]]
|
|
||||||
|
|||||||
89
unraid_mcp/core/utils.py
Normal file
89
unraid_mcp/core/utils.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
"""Shared utility functions for Unraid MCP tools."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
|
||||||
|
def safe_get(data: dict[str, Any], *keys: str, default: Any = None) -> Any:
|
||||||
|
"""Safely traverse nested dict keys, handling None intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: The root dictionary to traverse.
|
||||||
|
*keys: Sequence of keys to follow.
|
||||||
|
default: Value to return if any key is missing or None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The value at the end of the key chain, or default if unreachable.
|
||||||
|
Explicit ``None`` values at the final key also return ``default``.
|
||||||
|
"""
|
||||||
|
current = data
|
||||||
|
for key in keys:
|
||||||
|
if not isinstance(current, dict):
|
||||||
|
return default
|
||||||
|
current = current.get(key)
|
||||||
|
return current if current is not None else default
|
||||||
|
|
||||||
|
|
||||||
|
def format_bytes(bytes_value: int | None) -> str:
|
||||||
|
"""Format byte values into human-readable sizes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
bytes_value: Number of bytes, or None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Human-readable string like "1.00 GB" or "N/A" if input is None/invalid.
|
||||||
|
"""
|
||||||
|
if bytes_value is None:
|
||||||
|
return "N/A"
|
||||||
|
try:
|
||||||
|
value = float(int(bytes_value))
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return "N/A"
|
||||||
|
for unit in ["B", "KB", "MB", "GB", "TB", "PB"]:
|
||||||
|
if value < 1024.0:
|
||||||
|
return f"{value:.2f} {unit}"
|
||||||
|
value /= 1024.0
|
||||||
|
return f"{value:.2f} EB"
|
||||||
|
|
||||||
|
|
||||||
|
def safe_display_url(url: str | None) -> str | None:
|
||||||
|
"""Return a redacted URL showing only scheme + host + port.
|
||||||
|
|
||||||
|
Strips path, query parameters, credentials, and fragments to avoid
|
||||||
|
leaking internal network topology or embedded secrets (CWE-200).
|
||||||
|
"""
|
||||||
|
if not url:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
host = parsed.hostname or "unknown"
|
||||||
|
if parsed.port:
|
||||||
|
return f"{parsed.scheme}://{host}:{parsed.port}"
|
||||||
|
return f"{parsed.scheme}://{host}"
|
||||||
|
except ValueError:
|
||||||
|
# urlparse raises ValueError for invalid URLs (e.g. contains control chars)
|
||||||
|
return "<unparseable>"
|
||||||
|
|
||||||
|
|
||||||
|
def format_kb(k: Any) -> str:
|
||||||
|
"""Format kilobyte values into human-readable sizes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
k: Number of kilobytes, or None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Human-readable string like "1.00 GB" or "N/A" if input is None/invalid.
|
||||||
|
"""
|
||||||
|
if k is None:
|
||||||
|
return "N/A"
|
||||||
|
try:
|
||||||
|
k = int(k)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return "N/A"
|
||||||
|
if k >= 1024 * 1024 * 1024:
|
||||||
|
return f"{k / (1024 * 1024 * 1024):.2f} TB"
|
||||||
|
if k >= 1024 * 1024:
|
||||||
|
return f"{k / (1024 * 1024):.2f} GB"
|
||||||
|
if k >= 1024:
|
||||||
|
return f"{k / 1024:.2f} MB"
|
||||||
|
return f"{k:.2f} KB"
|
||||||
@@ -15,8 +15,11 @@ from .config.settings import (
|
|||||||
UNRAID_MCP_HOST,
|
UNRAID_MCP_HOST,
|
||||||
UNRAID_MCP_PORT,
|
UNRAID_MCP_PORT,
|
||||||
UNRAID_MCP_TRANSPORT,
|
UNRAID_MCP_TRANSPORT,
|
||||||
|
UNRAID_VERIFY_SSL,
|
||||||
VERSION,
|
VERSION,
|
||||||
|
validate_required_config,
|
||||||
)
|
)
|
||||||
|
from .subscriptions.diagnostics import register_diagnostic_tools
|
||||||
from .subscriptions.resources import register_subscription_resources
|
from .subscriptions.resources import register_subscription_resources
|
||||||
from .tools.array import register_array_tool
|
from .tools.array import register_array_tool
|
||||||
from .tools.docker import register_docker_tool
|
from .tools.docker import register_docker_tool
|
||||||
@@ -44,9 +47,10 @@ mcp = FastMCP(
|
|||||||
def register_all_modules() -> None:
|
def register_all_modules() -> None:
|
||||||
"""Register all tools and resources with the MCP instance."""
|
"""Register all tools and resources with the MCP instance."""
|
||||||
try:
|
try:
|
||||||
# Register subscription resources first
|
# Register subscription resources and diagnostic tools
|
||||||
register_subscription_resources(mcp)
|
register_subscription_resources(mcp)
|
||||||
logger.info("Subscription resources registered")
|
register_diagnostic_tools(mcp)
|
||||||
|
logger.info("Subscription resources and diagnostic tools registered")
|
||||||
|
|
||||||
# Register all consolidated tools
|
# Register all consolidated tools
|
||||||
registrars = [
|
registrars = [
|
||||||
@@ -73,6 +77,15 @@ def register_all_modules() -> None:
|
|||||||
|
|
||||||
def run_server() -> None:
|
def run_server() -> None:
|
||||||
"""Run the MCP server with the configured transport."""
|
"""Run the MCP server with the configured transport."""
|
||||||
|
# Validate required configuration before anything else
|
||||||
|
is_valid, missing = validate_required_config()
|
||||||
|
if not is_valid:
|
||||||
|
logger.critical(
|
||||||
|
f"Missing required configuration: {', '.join(missing)}. "
|
||||||
|
"Set these environment variables or add them to your .env file."
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
# Log configuration
|
# Log configuration
|
||||||
if UNRAID_API_URL:
|
if UNRAID_API_URL:
|
||||||
logger.info(f"UNRAID_API_URL loaded: {UNRAID_API_URL[:20]}...")
|
logger.info(f"UNRAID_API_URL loaded: {UNRAID_API_URL[:20]}...")
|
||||||
@@ -88,6 +101,13 @@ def run_server() -> None:
|
|||||||
logger.info(f"UNRAID_MCP_HOST set to: {UNRAID_MCP_HOST}")
|
logger.info(f"UNRAID_MCP_HOST set to: {UNRAID_MCP_HOST}")
|
||||||
logger.info(f"UNRAID_MCP_TRANSPORT set to: {UNRAID_MCP_TRANSPORT}")
|
logger.info(f"UNRAID_MCP_TRANSPORT set to: {UNRAID_MCP_TRANSPORT}")
|
||||||
|
|
||||||
|
if UNRAID_VERIFY_SSL is False:
|
||||||
|
logger.warning(
|
||||||
|
"SSL VERIFICATION DISABLED (UNRAID_VERIFY_SSL=false). "
|
||||||
|
"Connections to Unraid API are vulnerable to man-in-the-middle attacks. "
|
||||||
|
"Only use this in trusted networks or for development."
|
||||||
|
)
|
||||||
|
|
||||||
# Register all modules
|
# Register all modules
|
||||||
register_all_modules()
|
register_all_modules()
|
||||||
|
|
||||||
|
|||||||
@@ -6,8 +6,10 @@ development and debugging purposes.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import contextlib
|
||||||
import json
|
import json
|
||||||
from datetime import datetime
|
import re
|
||||||
|
from datetime import UTC, datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import websockets
|
import websockets
|
||||||
@@ -17,9 +19,63 @@ from websockets.typing import Subprotocol
|
|||||||
from ..config.logging import logger
|
from ..config.logging import logger
|
||||||
from ..config.settings import UNRAID_API_KEY, UNRAID_API_URL
|
from ..config.settings import UNRAID_API_KEY, UNRAID_API_URL
|
||||||
from ..core.exceptions import ToolError
|
from ..core.exceptions import ToolError
|
||||||
|
from ..core.utils import safe_display_url
|
||||||
from .manager import subscription_manager
|
from .manager import subscription_manager
|
||||||
from .resources import ensure_subscriptions_started
|
from .resources import ensure_subscriptions_started
|
||||||
from .utils import build_ws_ssl_context
|
from .utils import build_ws_ssl_context, build_ws_url
|
||||||
|
|
||||||
|
|
||||||
|
_ALLOWED_SUBSCRIPTION_NAMES = frozenset(
|
||||||
|
{
|
||||||
|
"logFileSubscription",
|
||||||
|
"containerStatsSubscription",
|
||||||
|
"cpuSubscription",
|
||||||
|
"memorySubscription",
|
||||||
|
"arraySubscription",
|
||||||
|
"networkSubscription",
|
||||||
|
"dockerSubscription",
|
||||||
|
"vmSubscription",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Pattern: must start with "subscription" and contain only a known subscription name.
|
||||||
|
# _FORBIDDEN_KEYWORDS rejects any query that contains standalone "mutation" or "query"
|
||||||
|
# as distinct words. Word boundaries (\b) ensure "mutationField"-style identifiers are
|
||||||
|
# not rejected — only bare "mutation" or "query" operation keywords are blocked.
|
||||||
|
_SUBSCRIPTION_NAME_PATTERN = re.compile(r"^\s*subscription\b[^{]*\{\s*(\w+)", re.IGNORECASE)
|
||||||
|
_FORBIDDEN_KEYWORDS = re.compile(r"\b(mutation|query)\b", re.IGNORECASE)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_subscription_query(query: str) -> str:
|
||||||
|
"""Validate that a subscription query is safe to execute.
|
||||||
|
|
||||||
|
Only allows subscription operations targeting whitelisted subscription names.
|
||||||
|
Rejects any query containing mutation/query keywords.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The extracted subscription name.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ToolError: If the query fails validation.
|
||||||
|
"""
|
||||||
|
if _FORBIDDEN_KEYWORDS.search(query):
|
||||||
|
raise ToolError("Query rejected: must be a subscription, not a mutation or query.")
|
||||||
|
|
||||||
|
match = _SUBSCRIPTION_NAME_PATTERN.match(query)
|
||||||
|
if not match:
|
||||||
|
raise ToolError(
|
||||||
|
"Query rejected: must start with 'subscription' and contain a valid "
|
||||||
|
"subscription operation. Example: subscription { logFileSubscription { ... } }"
|
||||||
|
)
|
||||||
|
|
||||||
|
sub_name = match.group(1)
|
||||||
|
if sub_name not in _ALLOWED_SUBSCRIPTION_NAMES:
|
||||||
|
raise ToolError(
|
||||||
|
f"Subscription '{sub_name}' is not allowed. "
|
||||||
|
f"Allowed subscriptions: {sorted(_ALLOWED_SUBSCRIPTION_NAMES)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return sub_name
|
||||||
|
|
||||||
|
|
||||||
def register_diagnostic_tools(mcp: FastMCP) -> None:
|
def register_diagnostic_tools(mcp: FastMCP) -> None:
|
||||||
@@ -34,6 +90,10 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
|||||||
"""Test a GraphQL subscription query directly to debug schema issues.
|
"""Test a GraphQL subscription query directly to debug schema issues.
|
||||||
|
|
||||||
Use this to find working subscription field names and structure.
|
Use this to find working subscription field names and structure.
|
||||||
|
Only whitelisted subscriptions are allowed (logFileSubscription,
|
||||||
|
containerStatsSubscription, cpuSubscription, memorySubscription,
|
||||||
|
arraySubscription, networkSubscription, dockerSubscription,
|
||||||
|
vmSubscription).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
subscription_query: The GraphQL subscription query to test
|
subscription_query: The GraphQL subscription query to test
|
||||||
@@ -41,16 +101,16 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
|||||||
Returns:
|
Returns:
|
||||||
Dict containing test results and response data
|
Dict containing test results and response data
|
||||||
"""
|
"""
|
||||||
try:
|
# Validate before any network I/O
|
||||||
logger.info(f"[TEST_SUBSCRIPTION] Testing query: {subscription_query}")
|
sub_name = _validate_subscription_query(subscription_query)
|
||||||
|
|
||||||
# Build WebSocket URL
|
try:
|
||||||
if not UNRAID_API_URL:
|
logger.info(f"[TEST_SUBSCRIPTION] Testing validated subscription '{sub_name}'")
|
||||||
raise ToolError("UNRAID_API_URL is not configured")
|
|
||||||
ws_url = (
|
try:
|
||||||
UNRAID_API_URL.replace("https://", "wss://").replace("http://", "ws://")
|
ws_url = build_ws_url()
|
||||||
+ "/graphql"
|
except ValueError as e:
|
||||||
)
|
raise ToolError(str(e)) from e
|
||||||
|
|
||||||
ssl_context = build_ws_ssl_context(ws_url)
|
ssl_context = build_ws_ssl_context(ws_url)
|
||||||
|
|
||||||
@@ -59,6 +119,7 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
|||||||
ws_url,
|
ws_url,
|
||||||
subprotocols=[Subprotocol("graphql-transport-ws"), Subprotocol("graphql-ws")],
|
subprotocols=[Subprotocol("graphql-transport-ws"), Subprotocol("graphql-ws")],
|
||||||
ssl=ssl_context,
|
ssl=ssl_context,
|
||||||
|
open_timeout=10,
|
||||||
ping_interval=30,
|
ping_interval=30,
|
||||||
ping_timeout=10,
|
ping_timeout=10,
|
||||||
) as websocket:
|
) as websocket:
|
||||||
@@ -102,6 +163,8 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
|||||||
"note": "Connection successful, subscription may be waiting for events",
|
"note": "Connection successful, subscription may be waiting for events",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
except ToolError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[TEST_SUBSCRIPTION] Error: {e}", exc_info=True)
|
logger.error(f"[TEST_SUBSCRIPTION] Error: {e}", exc_info=True)
|
||||||
return {"error": str(e), "query_tested": subscription_query}
|
return {"error": str(e), "query_tested": subscription_query}
|
||||||
@@ -122,18 +185,18 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
|||||||
logger.info("[DIAGNOSTIC] Running subscription diagnostics...")
|
logger.info("[DIAGNOSTIC] Running subscription diagnostics...")
|
||||||
|
|
||||||
# Get comprehensive status
|
# Get comprehensive status
|
||||||
status = subscription_manager.get_subscription_status()
|
status = await subscription_manager.get_subscription_status()
|
||||||
|
|
||||||
# Initialize connection issues list with proper type
|
# Initialize connection issues list with proper type
|
||||||
connection_issues: list[dict[str, Any]] = []
|
connection_issues: list[dict[str, Any]] = []
|
||||||
|
|
||||||
# Add environment info with explicit typing
|
# Add environment info with explicit typing
|
||||||
diagnostic_info: dict[str, Any] = {
|
diagnostic_info: dict[str, Any] = {
|
||||||
"timestamp": datetime.now().isoformat(),
|
"timestamp": datetime.now(UTC).isoformat(),
|
||||||
"environment": {
|
"environment": {
|
||||||
"auto_start_enabled": subscription_manager.auto_start_enabled,
|
"auto_start_enabled": subscription_manager.auto_start_enabled,
|
||||||
"max_reconnect_attempts": subscription_manager.max_reconnect_attempts,
|
"max_reconnect_attempts": subscription_manager.max_reconnect_attempts,
|
||||||
"unraid_api_url": UNRAID_API_URL[:50] + "..." if UNRAID_API_URL else None,
|
"unraid_api_url": safe_display_url(UNRAID_API_URL),
|
||||||
"api_key_configured": bool(UNRAID_API_KEY),
|
"api_key_configured": bool(UNRAID_API_KEY),
|
||||||
"websocket_url": None,
|
"websocket_url": None,
|
||||||
},
|
},
|
||||||
@@ -152,17 +215,9 @@ def register_diagnostic_tools(mcp: FastMCP) -> None:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Calculate WebSocket URL
|
# Calculate WebSocket URL (stays None if UNRAID_API_URL not configured)
|
||||||
if UNRAID_API_URL:
|
with contextlib.suppress(ValueError):
|
||||||
if UNRAID_API_URL.startswith("https://"):
|
diagnostic_info["environment"]["websocket_url"] = build_ws_url()
|
||||||
ws_url = "wss://" + UNRAID_API_URL[len("https://") :]
|
|
||||||
elif UNRAID_API_URL.startswith("http://"):
|
|
||||||
ws_url = "ws://" + UNRAID_API_URL[len("http://") :]
|
|
||||||
else:
|
|
||||||
ws_url = UNRAID_API_URL
|
|
||||||
if not ws_url.endswith("/graphql"):
|
|
||||||
ws_url = ws_url.rstrip("/") + "/graphql"
|
|
||||||
diagnostic_info["environment"]["websocket_url"] = ws_url
|
|
||||||
|
|
||||||
# Analyze issues
|
# Analyze issues
|
||||||
for sub_name, sub_status in status.items():
|
for sub_name, sub_status in status.items():
|
||||||
|
|||||||
@@ -8,16 +8,74 @@ error handling, reconnection logic, and authentication.
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
import time
|
||||||
|
from datetime import UTC, datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import websockets
|
import websockets
|
||||||
from websockets.typing import Subprotocol
|
from websockets.typing import Subprotocol
|
||||||
|
|
||||||
from ..config.logging import logger
|
from ..config.logging import logger
|
||||||
from ..config.settings import UNRAID_API_KEY, UNRAID_API_URL
|
from ..config.settings import UNRAID_API_KEY
|
||||||
|
from ..core.client import redact_sensitive
|
||||||
from ..core.types import SubscriptionData
|
from ..core.types import SubscriptionData
|
||||||
from .utils import build_ws_ssl_context
|
from .utils import build_ws_ssl_context, build_ws_url
|
||||||
|
|
||||||
|
|
||||||
|
# Resource data size limits to prevent unbounded memory growth
|
||||||
|
_MAX_RESOURCE_DATA_BYTES = 1_048_576 # 1MB
|
||||||
|
_MAX_RESOURCE_DATA_LINES = 5_000
|
||||||
|
# Minimum stable connection duration (seconds) before resetting reconnect counter
|
||||||
|
_STABLE_CONNECTION_SECONDS = 30
|
||||||
|
|
||||||
|
|
||||||
|
def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Cap log content in subscription data to prevent unbounded memory growth.
|
||||||
|
|
||||||
|
Returns a new dict — does NOT mutate the input. If any nested 'content'
|
||||||
|
field (from log subscriptions) exceeds the byte limit, truncate it to the
|
||||||
|
most recent _MAX_RESOURCE_DATA_LINES lines.
|
||||||
|
|
||||||
|
The final content is guaranteed to be <= _MAX_RESOURCE_DATA_BYTES.
|
||||||
|
"""
|
||||||
|
result: dict[str, Any] = {}
|
||||||
|
for key, value in data.items():
|
||||||
|
if isinstance(value, dict):
|
||||||
|
result[key] = _cap_log_content(value)
|
||||||
|
elif (
|
||||||
|
key == "content"
|
||||||
|
and isinstance(value, str)
|
||||||
|
and len(value.encode("utf-8", errors="replace")) > _MAX_RESOURCE_DATA_BYTES
|
||||||
|
):
|
||||||
|
lines = value.splitlines()
|
||||||
|
original_line_count = len(lines)
|
||||||
|
|
||||||
|
# Keep most recent lines first.
|
||||||
|
if len(lines) > _MAX_RESOURCE_DATA_LINES:
|
||||||
|
lines = lines[-_MAX_RESOURCE_DATA_LINES:]
|
||||||
|
|
||||||
|
# Enforce byte cap while preserving whole-line boundaries where possible.
|
||||||
|
truncated = "\n".join(lines)
|
||||||
|
truncated_bytes = truncated.encode("utf-8", errors="replace")
|
||||||
|
while len(lines) > 1 and len(truncated_bytes) > _MAX_RESOURCE_DATA_BYTES:
|
||||||
|
lines = lines[1:]
|
||||||
|
truncated = "\n".join(lines)
|
||||||
|
truncated_bytes = truncated.encode("utf-8", errors="replace")
|
||||||
|
|
||||||
|
# Last resort: if a single line still exceeds cap, hard-cap bytes.
|
||||||
|
if len(truncated_bytes) > _MAX_RESOURCE_DATA_BYTES:
|
||||||
|
truncated = truncated_bytes[-_MAX_RESOURCE_DATA_BYTES :].decode(
|
||||||
|
"utf-8", errors="ignore"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
f"[RESOURCE] Capped log content from {original_line_count} to "
|
||||||
|
f"{len(lines)} lines ({len(value)} -> {len(truncated)} chars)"
|
||||||
|
)
|
||||||
|
result[key] = truncated
|
||||||
|
else:
|
||||||
|
result[key] = value
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class SubscriptionManager:
|
class SubscriptionManager:
|
||||||
@@ -26,7 +84,6 @@ class SubscriptionManager:
|
|||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.active_subscriptions: dict[str, asyncio.Task[None]] = {}
|
self.active_subscriptions: dict[str, asyncio.Task[None]] = {}
|
||||||
self.resource_data: dict[str, SubscriptionData] = {}
|
self.resource_data: dict[str, SubscriptionData] = {}
|
||||||
self.websocket: websockets.WebSocketServerProtocol | None = None
|
|
||||||
self.subscription_lock = asyncio.Lock()
|
self.subscription_lock = asyncio.Lock()
|
||||||
|
|
||||||
# Configuration
|
# Configuration
|
||||||
@@ -37,6 +94,7 @@ class SubscriptionManager:
|
|||||||
self.max_reconnect_attempts = int(os.getenv("UNRAID_MAX_RECONNECT_ATTEMPTS", "10"))
|
self.max_reconnect_attempts = int(os.getenv("UNRAID_MAX_RECONNECT_ATTEMPTS", "10"))
|
||||||
self.connection_states: dict[str, str] = {} # Track connection state per subscription
|
self.connection_states: dict[str, str] = {} # Track connection state per subscription
|
||||||
self.last_error: dict[str, str] = {} # Track last error per subscription
|
self.last_error: dict[str, str] = {} # Track last error per subscription
|
||||||
|
self._connection_start_times: dict[str, float] = {} # Track when connections started
|
||||||
|
|
||||||
# Define subscription configurations
|
# Define subscription configurations
|
||||||
self.subscription_configs = {
|
self.subscription_configs = {
|
||||||
@@ -105,6 +163,7 @@ class SubscriptionManager:
|
|||||||
# Reset connection tracking
|
# Reset connection tracking
|
||||||
self.reconnect_attempts[subscription_name] = 0
|
self.reconnect_attempts[subscription_name] = 0
|
||||||
self.connection_states[subscription_name] = "starting"
|
self.connection_states[subscription_name] = "starting"
|
||||||
|
self._connection_start_times.pop(subscription_name, None)
|
||||||
|
|
||||||
async with self.subscription_lock:
|
async with self.subscription_lock:
|
||||||
try:
|
try:
|
||||||
@@ -138,6 +197,7 @@ class SubscriptionManager:
|
|||||||
logger.debug(f"[SUBSCRIPTION:{subscription_name}] Task cancelled successfully")
|
logger.debug(f"[SUBSCRIPTION:{subscription_name}] Task cancelled successfully")
|
||||||
del self.active_subscriptions[subscription_name]
|
del self.active_subscriptions[subscription_name]
|
||||||
self.connection_states[subscription_name] = "stopped"
|
self.connection_states[subscription_name] = "stopped"
|
||||||
|
self._connection_start_times.pop(subscription_name, None)
|
||||||
logger.info(f"[SUBSCRIPTION:{subscription_name}] Subscription stopped")
|
logger.info(f"[SUBSCRIPTION:{subscription_name}] Subscription stopped")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"[SUBSCRIPTION:{subscription_name}] No active subscription to stop")
|
logger.warning(f"[SUBSCRIPTION:{subscription_name}] No active subscription to stop")
|
||||||
@@ -165,20 +225,7 @@ class SubscriptionManager:
|
|||||||
break
|
break
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Build WebSocket URL with detailed logging
|
ws_url = build_ws_url()
|
||||||
if not UNRAID_API_URL:
|
|
||||||
raise ValueError("UNRAID_API_URL is not configured")
|
|
||||||
|
|
||||||
if UNRAID_API_URL.startswith("https://"):
|
|
||||||
ws_url = "wss://" + UNRAID_API_URL[len("https://") :]
|
|
||||||
elif UNRAID_API_URL.startswith("http://"):
|
|
||||||
ws_url = "ws://" + UNRAID_API_URL[len("http://") :]
|
|
||||||
else:
|
|
||||||
ws_url = UNRAID_API_URL
|
|
||||||
|
|
||||||
if not ws_url.endswith("/graphql"):
|
|
||||||
ws_url = ws_url.rstrip("/") + "/graphql"
|
|
||||||
|
|
||||||
logger.debug(f"[WEBSOCKET:{subscription_name}] Connecting to: {ws_url}")
|
logger.debug(f"[WEBSOCKET:{subscription_name}] Connecting to: {ws_url}")
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[WEBSOCKET:{subscription_name}] API Key present: {'Yes' if UNRAID_API_KEY else 'No'}"
|
f"[WEBSOCKET:{subscription_name}] API Key present: {'Yes' if UNRAID_API_KEY else 'No'}"
|
||||||
@@ -195,6 +242,7 @@ class SubscriptionManager:
|
|||||||
async with websockets.connect(
|
async with websockets.connect(
|
||||||
ws_url,
|
ws_url,
|
||||||
subprotocols=[Subprotocol("graphql-transport-ws"), Subprotocol("graphql-ws")],
|
subprotocols=[Subprotocol("graphql-transport-ws"), Subprotocol("graphql-ws")],
|
||||||
|
open_timeout=connect_timeout,
|
||||||
ping_interval=20,
|
ping_interval=20,
|
||||||
ping_timeout=10,
|
ping_timeout=10,
|
||||||
close_timeout=10,
|
close_timeout=10,
|
||||||
@@ -206,9 +254,9 @@ class SubscriptionManager:
|
|||||||
)
|
)
|
||||||
self.connection_states[subscription_name] = "connected"
|
self.connection_states[subscription_name] = "connected"
|
||||||
|
|
||||||
# Reset retry count on successful connection
|
# Track connection start time — only reset retry counter
|
||||||
self.reconnect_attempts[subscription_name] = 0
|
# after the connection proves stable (>30s connected)
|
||||||
retry_delay = 5 # Reset delay
|
self._connection_start_times[subscription_name] = time.monotonic()
|
||||||
|
|
||||||
# Initialize GraphQL-WS protocol
|
# Initialize GraphQL-WS protocol
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -290,7 +338,9 @@ class SubscriptionManager:
|
|||||||
f"[SUBSCRIPTION:{subscription_name}] Subscription message type: {start_type}"
|
f"[SUBSCRIPTION:{subscription_name}] Subscription message type: {start_type}"
|
||||||
)
|
)
|
||||||
logger.debug(f"[SUBSCRIPTION:{subscription_name}] Query: {query[:100]}...")
|
logger.debug(f"[SUBSCRIPTION:{subscription_name}] Query: {query[:100]}...")
|
||||||
logger.debug(f"[SUBSCRIPTION:{subscription_name}] Variables: {variables}")
|
logger.debug(
|
||||||
|
f"[SUBSCRIPTION:{subscription_name}] Variables: {redact_sensitive(variables)}"
|
||||||
|
)
|
||||||
|
|
||||||
await websocket.send(json.dumps(subscription_message))
|
await websocket.send(json.dumps(subscription_message))
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -326,11 +376,18 @@ class SubscriptionManager:
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"[DATA:{subscription_name}] Received subscription data update"
|
f"[DATA:{subscription_name}] Received subscription data update"
|
||||||
)
|
)
|
||||||
self.resource_data[subscription_name] = SubscriptionData(
|
capped_data = (
|
||||||
data=payload["data"],
|
_cap_log_content(payload["data"])
|
||||||
last_updated=datetime.now(),
|
if isinstance(payload["data"], dict)
|
||||||
|
else payload["data"]
|
||||||
|
)
|
||||||
|
new_entry = SubscriptionData(
|
||||||
|
data=capped_data,
|
||||||
|
last_updated=datetime.now(UTC),
|
||||||
subscription_type=subscription_name,
|
subscription_type=subscription_name,
|
||||||
)
|
)
|
||||||
|
async with self.subscription_lock:
|
||||||
|
self.resource_data[subscription_name] = new_entry
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[RESOURCE:{subscription_name}] Resource data updated successfully"
|
f"[RESOURCE:{subscription_name}] Resource data updated successfully"
|
||||||
)
|
)
|
||||||
@@ -391,7 +448,8 @@ class SubscriptionManager:
|
|||||||
logger.error(f"[PROTOCOL:{subscription_name}] JSON decode error: {e}")
|
logger.error(f"[PROTOCOL:{subscription_name}] JSON decode error: {e}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[DATA:{subscription_name}] Error processing message: {e}"
|
f"[DATA:{subscription_name}] Error processing message: {e}",
|
||||||
|
exc_info=True,
|
||||||
)
|
)
|
||||||
msg_preview = (
|
msg_preview = (
|
||||||
message[:200]
|
message[:200]
|
||||||
@@ -421,11 +479,39 @@ class SubscriptionManager:
|
|||||||
self.connection_states[subscription_name] = "invalid_uri"
|
self.connection_states[subscription_name] = "invalid_uri"
|
||||||
break # Don't retry on invalid URI
|
break # Don't retry on invalid URI
|
||||||
|
|
||||||
except Exception as e:
|
except ValueError as e:
|
||||||
error_msg = f"Unexpected error: {e}"
|
# Non-retryable configuration error (e.g. UNRAID_API_URL not set)
|
||||||
|
error_msg = f"Configuration error: {e}"
|
||||||
logger.error(f"[WEBSOCKET:{subscription_name}] {error_msg}")
|
logger.error(f"[WEBSOCKET:{subscription_name}] {error_msg}")
|
||||||
self.last_error[subscription_name] = error_msg
|
self.last_error[subscription_name] = error_msg
|
||||||
self.connection_states[subscription_name] = "error"
|
self.connection_states[subscription_name] = "error"
|
||||||
|
break # Don't retry on configuration errors
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Unexpected error: {e}"
|
||||||
|
logger.error(f"[WEBSOCKET:{subscription_name}] {error_msg}", exc_info=True)
|
||||||
|
self.last_error[subscription_name] = error_msg
|
||||||
|
self.connection_states[subscription_name] = "error"
|
||||||
|
|
||||||
|
# Check if connection was stable before deciding on retry behavior
|
||||||
|
start_time = self._connection_start_times.pop(subscription_name, None)
|
||||||
|
if start_time is not None:
|
||||||
|
connected_duration = time.monotonic() - start_time
|
||||||
|
if connected_duration >= _STABLE_CONNECTION_SECONDS:
|
||||||
|
# Connection was stable — reset retry counter and backoff
|
||||||
|
logger.info(
|
||||||
|
f"[WEBSOCKET:{subscription_name}] Connection was stable "
|
||||||
|
f"({connected_duration:.0f}s >= {_STABLE_CONNECTION_SECONDS}s), "
|
||||||
|
f"resetting retry counter"
|
||||||
|
)
|
||||||
|
self.reconnect_attempts[subscription_name] = 0
|
||||||
|
retry_delay = 5
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"[WEBSOCKET:{subscription_name}] Connection was unstable "
|
||||||
|
f"({connected_duration:.0f}s < {_STABLE_CONNECTION_SECONDS}s), "
|
||||||
|
f"keeping retry counter at {self.reconnect_attempts.get(subscription_name, 0)}"
|
||||||
|
)
|
||||||
|
|
||||||
# Calculate backoff delay
|
# Calculate backoff delay
|
||||||
retry_delay = min(retry_delay * 1.5, max_retry_delay)
|
retry_delay = min(retry_delay * 1.5, max_retry_delay)
|
||||||
@@ -435,15 +521,26 @@ class SubscriptionManager:
|
|||||||
self.connection_states[subscription_name] = "reconnecting"
|
self.connection_states[subscription_name] = "reconnecting"
|
||||||
await asyncio.sleep(retry_delay)
|
await asyncio.sleep(retry_delay)
|
||||||
|
|
||||||
def get_resource_data(self, resource_name: str) -> dict[str, Any] | None:
|
# The while loop exited (via break or max_retries exceeded).
|
||||||
|
# Remove from active_subscriptions so start_subscription() can restart it.
|
||||||
|
async with self.subscription_lock:
|
||||||
|
self.active_subscriptions.pop(subscription_name, None)
|
||||||
|
logger.info(
|
||||||
|
f"[SUBSCRIPTION:{subscription_name}] Subscription loop ended — "
|
||||||
|
f"removed from active_subscriptions. Final state: "
|
||||||
|
f"{self.connection_states.get(subscription_name, 'unknown')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_resource_data(self, resource_name: str) -> dict[str, Any] | None:
|
||||||
"""Get current resource data with enhanced logging."""
|
"""Get current resource data with enhanced logging."""
|
||||||
logger.debug(f"[RESOURCE:{resource_name}] Resource data requested")
|
logger.debug(f"[RESOURCE:{resource_name}] Resource data requested")
|
||||||
|
|
||||||
if resource_name in self.resource_data:
|
async with self.subscription_lock:
|
||||||
data = self.resource_data[resource_name]
|
if resource_name in self.resource_data:
|
||||||
age_seconds = (datetime.now() - data.last_updated).total_seconds()
|
data = self.resource_data[resource_name]
|
||||||
logger.debug(f"[RESOURCE:{resource_name}] Data found, age: {age_seconds:.1f}s")
|
age_seconds = (datetime.now(UTC) - data.last_updated).total_seconds()
|
||||||
return data.data
|
logger.debug(f"[RESOURCE:{resource_name}] Data found, age: {age_seconds:.1f}s")
|
||||||
|
return data.data
|
||||||
logger.debug(f"[RESOURCE:{resource_name}] No data available")
|
logger.debug(f"[RESOURCE:{resource_name}] No data available")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -453,38 +550,39 @@ class SubscriptionManager:
|
|||||||
logger.debug(f"[SUBSCRIPTION_MANAGER] Active subscriptions: {active}")
|
logger.debug(f"[SUBSCRIPTION_MANAGER] Active subscriptions: {active}")
|
||||||
return active
|
return active
|
||||||
|
|
||||||
def get_subscription_status(self) -> dict[str, dict[str, Any]]:
|
async def get_subscription_status(self) -> dict[str, dict[str, Any]]:
|
||||||
"""Get detailed status of all subscriptions for diagnostics."""
|
"""Get detailed status of all subscriptions for diagnostics."""
|
||||||
status = {}
|
status = {}
|
||||||
|
|
||||||
for sub_name, config in self.subscription_configs.items():
|
async with self.subscription_lock:
|
||||||
sub_status = {
|
for sub_name, config in self.subscription_configs.items():
|
||||||
"config": {
|
sub_status = {
|
||||||
"resource": config["resource"],
|
"config": {
|
||||||
"description": config["description"],
|
"resource": config["resource"],
|
||||||
"auto_start": config.get("auto_start", False),
|
"description": config["description"],
|
||||||
},
|
"auto_start": config.get("auto_start", False),
|
||||||
"runtime": {
|
},
|
||||||
"active": sub_name in self.active_subscriptions,
|
"runtime": {
|
||||||
"connection_state": self.connection_states.get(sub_name, "not_started"),
|
"active": sub_name in self.active_subscriptions,
|
||||||
"reconnect_attempts": self.reconnect_attempts.get(sub_name, 0),
|
"connection_state": self.connection_states.get(sub_name, "not_started"),
|
||||||
"last_error": self.last_error.get(sub_name, None),
|
"reconnect_attempts": self.reconnect_attempts.get(sub_name, 0),
|
||||||
},
|
"last_error": self.last_error.get(sub_name, None),
|
||||||
}
|
},
|
||||||
|
|
||||||
# Add data info if available
|
|
||||||
if sub_name in self.resource_data:
|
|
||||||
data_info = self.resource_data[sub_name]
|
|
||||||
age_seconds = (datetime.now() - data_info.last_updated).total_seconds()
|
|
||||||
sub_status["data"] = {
|
|
||||||
"available": True,
|
|
||||||
"last_updated": data_info.last_updated.isoformat(),
|
|
||||||
"age_seconds": age_seconds,
|
|
||||||
}
|
}
|
||||||
else:
|
|
||||||
sub_status["data"] = {"available": False}
|
|
||||||
|
|
||||||
status[sub_name] = sub_status
|
# Add data info if available
|
||||||
|
if sub_name in self.resource_data:
|
||||||
|
data_info = self.resource_data[sub_name]
|
||||||
|
age_seconds = (datetime.now(UTC) - data_info.last_updated).total_seconds()
|
||||||
|
sub_status["data"] = {
|
||||||
|
"available": True,
|
||||||
|
"last_updated": data_info.last_updated.isoformat(),
|
||||||
|
"age_seconds": age_seconds,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
sub_status["data"] = {"available": False}
|
||||||
|
|
||||||
|
status[sub_name] = sub_status
|
||||||
|
|
||||||
logger.debug(f"[SUBSCRIPTION_MANAGER] Generated status for {len(status)} subscriptions")
|
logger.debug(f"[SUBSCRIPTION_MANAGER] Generated status for {len(status)} subscriptions")
|
||||||
return status
|
return status
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ async def autostart_subscriptions() -> None:
|
|||||||
logger.info("[AUTOSTART] Auto-start process completed successfully")
|
logger.info("[AUTOSTART] Auto-start process completed successfully")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[AUTOSTART] Failed during auto-start process: {e}", exc_info=True)
|
logger.error(f"[AUTOSTART] Failed during auto-start process: {e}", exc_info=True)
|
||||||
|
raise # Propagate so ensure_subscriptions_started doesn't mark as started
|
||||||
|
|
||||||
# Optional log file subscription
|
# Optional log file subscription
|
||||||
log_path = os.getenv("UNRAID_AUTOSTART_LOG_PATH")
|
log_path = os.getenv("UNRAID_AUTOSTART_LOG_PATH")
|
||||||
@@ -82,7 +83,7 @@ def register_subscription_resources(mcp: FastMCP) -> None:
|
|||||||
async def logs_stream_resource() -> str:
|
async def logs_stream_resource() -> str:
|
||||||
"""Real-time log stream data from subscription."""
|
"""Real-time log stream data from subscription."""
|
||||||
await ensure_subscriptions_started()
|
await ensure_subscriptions_started()
|
||||||
data = subscription_manager.get_resource_data("logFileSubscription")
|
data = await subscription_manager.get_resource_data("logFileSubscription")
|
||||||
if data:
|
if data:
|
||||||
return json.dumps(data, indent=2)
|
return json.dumps(data, indent=2)
|
||||||
return json.dumps(
|
return json.dumps(
|
||||||
|
|||||||
@@ -2,7 +2,34 @@
|
|||||||
|
|
||||||
import ssl as _ssl
|
import ssl as _ssl
|
||||||
|
|
||||||
from ..config.settings import UNRAID_VERIFY_SSL
|
from ..config.settings import UNRAID_API_URL, UNRAID_VERIFY_SSL
|
||||||
|
|
||||||
|
|
||||||
|
def build_ws_url() -> str:
|
||||||
|
"""Build a WebSocket URL from the configured UNRAID_API_URL.
|
||||||
|
|
||||||
|
Converts http(s) scheme to ws(s) and ensures /graphql path suffix.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The WebSocket URL string (e.g. "wss://10.1.0.2:31337/graphql").
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If UNRAID_API_URL is not configured.
|
||||||
|
"""
|
||||||
|
if not UNRAID_API_URL:
|
||||||
|
raise ValueError("UNRAID_API_URL is not configured")
|
||||||
|
|
||||||
|
if UNRAID_API_URL.startswith("https://"):
|
||||||
|
ws_url = "wss://" + UNRAID_API_URL[len("https://") :]
|
||||||
|
elif UNRAID_API_URL.startswith("http://"):
|
||||||
|
ws_url = "ws://" + UNRAID_API_URL[len("http://") :]
|
||||||
|
else:
|
||||||
|
ws_url = UNRAID_API_URL
|
||||||
|
|
||||||
|
if not ws_url.endswith("/graphql"):
|
||||||
|
ws_url = ws_url.rstrip("/") + "/graphql"
|
||||||
|
|
||||||
|
return ws_url
|
||||||
|
|
||||||
|
|
||||||
def build_ws_ssl_context(ws_url: str) -> _ssl.SSLContext | None:
|
def build_ws_ssl_context(ws_url: str) -> _ssl.SSLContext | None:
|
||||||
|
|||||||
@@ -3,13 +3,13 @@
|
|||||||
Provides the `unraid_array` tool with 5 actions for parity check management.
|
Provides the `unraid_array` tool with 5 actions for parity check management.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, get_args
|
||||||
|
|
||||||
from fastmcp import FastMCP
|
from fastmcp import FastMCP
|
||||||
|
|
||||||
from ..config.logging import logger
|
from ..config.logging import logger
|
||||||
from ..core.client import make_graphql_request
|
from ..core.client import make_graphql_request
|
||||||
from ..core.exceptions import ToolError
|
from ..core.exceptions import ToolError, tool_error_handler
|
||||||
|
|
||||||
|
|
||||||
QUERIES: dict[str, str] = {
|
QUERIES: dict[str, str] = {
|
||||||
@@ -53,6 +53,14 @@ ARRAY_ACTIONS = Literal[
|
|||||||
"parity_status",
|
"parity_status",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if set(get_args(ARRAY_ACTIONS)) != ALL_ACTIONS:
|
||||||
|
_missing = ALL_ACTIONS - set(get_args(ARRAY_ACTIONS))
|
||||||
|
_extra = set(get_args(ARRAY_ACTIONS)) - ALL_ACTIONS
|
||||||
|
raise RuntimeError(
|
||||||
|
f"ARRAY_ACTIONS and ALL_ACTIONS are out of sync. "
|
||||||
|
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def register_array_tool(mcp: FastMCP) -> None:
|
def register_array_tool(mcp: FastMCP) -> None:
|
||||||
"""Register the unraid_array tool with the FastMCP instance."""
|
"""Register the unraid_array tool with the FastMCP instance."""
|
||||||
@@ -74,7 +82,7 @@ def register_array_tool(mcp: FastMCP) -> None:
|
|||||||
if action not in ALL_ACTIONS:
|
if action not in ALL_ACTIONS:
|
||||||
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(ALL_ACTIONS)}")
|
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(ALL_ACTIONS)}")
|
||||||
|
|
||||||
try:
|
with tool_error_handler("array", action, logger):
|
||||||
logger.info(f"Executing unraid_array action={action}")
|
logger.info(f"Executing unraid_array action={action}")
|
||||||
|
|
||||||
if action in QUERIES:
|
if action in QUERIES:
|
||||||
@@ -95,10 +103,4 @@ def register_array_tool(mcp: FastMCP) -> None:
|
|||||||
"data": data,
|
"data": data,
|
||||||
}
|
}
|
||||||
|
|
||||||
except ToolError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in unraid_array action={action}: {e}", exc_info=True)
|
|
||||||
raise ToolError(f"Failed to execute array/{action}: {e!s}") from e
|
|
||||||
|
|
||||||
logger.info("Array tool registered successfully")
|
logger.info("Array tool registered successfully")
|
||||||
|
|||||||
@@ -5,13 +5,14 @@ logs, networks, and update management.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, get_args
|
||||||
|
|
||||||
from fastmcp import FastMCP
|
from fastmcp import FastMCP
|
||||||
|
|
||||||
from ..config.logging import logger
|
from ..config.logging import logger
|
||||||
from ..core.client import make_graphql_request
|
from ..core.client import make_graphql_request
|
||||||
from ..core.exceptions import ToolError
|
from ..core.exceptions import ToolError, tool_error_handler
|
||||||
|
from ..core.utils import safe_get
|
||||||
|
|
||||||
|
|
||||||
QUERIES: dict[str, str] = {
|
QUERIES: dict[str, str] = {
|
||||||
@@ -98,7 +99,10 @@ MUTATIONS: dict[str, str] = {
|
|||||||
""",
|
""",
|
||||||
}
|
}
|
||||||
|
|
||||||
DESTRUCTIVE_ACTIONS = {"remove"}
|
DESTRUCTIVE_ACTIONS = {"remove", "update_all"}
|
||||||
|
# NOTE (Code-M-07): "details" and "logs" are listed here because they require a
|
||||||
|
# container_id parameter, but unlike mutations they use fuzzy name matching (not
|
||||||
|
# strict). This is intentional: read-only queries are safe with fuzzy matching.
|
||||||
_ACTIONS_REQUIRING_CONTAINER_ID = {
|
_ACTIONS_REQUIRING_CONTAINER_ID = {
|
||||||
"start",
|
"start",
|
||||||
"stop",
|
"stop",
|
||||||
@@ -111,6 +115,7 @@ _ACTIONS_REQUIRING_CONTAINER_ID = {
|
|||||||
"logs",
|
"logs",
|
||||||
}
|
}
|
||||||
ALL_ACTIONS = set(QUERIES) | set(MUTATIONS) | {"restart"}
|
ALL_ACTIONS = set(QUERIES) | set(MUTATIONS) | {"restart"}
|
||||||
|
_MAX_TAIL_LINES = 10_000
|
||||||
|
|
||||||
DOCKER_ACTIONS = Literal[
|
DOCKER_ACTIONS = Literal[
|
||||||
"list",
|
"list",
|
||||||
@@ -130,33 +135,36 @@ DOCKER_ACTIONS = Literal[
|
|||||||
"check_updates",
|
"check_updates",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Docker container IDs: 64 hex chars + optional suffix (e.g., ":local")
|
if set(get_args(DOCKER_ACTIONS)) != ALL_ACTIONS:
|
||||||
|
_missing = ALL_ACTIONS - set(get_args(DOCKER_ACTIONS))
|
||||||
|
_extra = set(get_args(DOCKER_ACTIONS)) - ALL_ACTIONS
|
||||||
|
raise RuntimeError(
|
||||||
|
f"DOCKER_ACTIONS and ALL_ACTIONS are out of sync. "
|
||||||
|
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Full PrefixedID: 64 hex chars + optional suffix (e.g., ":local")
|
||||||
_DOCKER_ID_PATTERN = re.compile(r"^[a-f0-9]{64}(:[a-z0-9]+)?$", re.IGNORECASE)
|
_DOCKER_ID_PATTERN = re.compile(r"^[a-f0-9]{64}(:[a-z0-9]+)?$", re.IGNORECASE)
|
||||||
|
|
||||||
|
# Short hex prefix: at least 12 hex chars (standard Docker short ID length)
|
||||||
def _safe_get(data: dict[str, Any], *keys: str, default: Any = None) -> Any:
|
_DOCKER_SHORT_ID_PATTERN = re.compile(r"^[a-f0-9]{12,63}$", re.IGNORECASE)
|
||||||
"""Safely traverse nested dict keys, handling None intermediates."""
|
|
||||||
current = data
|
|
||||||
for key in keys:
|
|
||||||
if not isinstance(current, dict):
|
|
||||||
return default
|
|
||||||
current = current.get(key)
|
|
||||||
return current if current is not None else default
|
|
||||||
|
|
||||||
|
|
||||||
def find_container_by_identifier(
|
def find_container_by_identifier(
|
||||||
identifier: str, containers: list[dict[str, Any]]
|
identifier: str, containers: list[dict[str, Any]], *, strict: bool = False
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Find a container by ID or name with fuzzy matching.
|
"""Find a container by ID or name with optional fuzzy matching.
|
||||||
|
|
||||||
Match priority:
|
Match priority:
|
||||||
1. Exact ID match
|
1. Exact ID match
|
||||||
2. Exact name match (case-sensitive)
|
2. Exact name match (case-sensitive)
|
||||||
|
|
||||||
|
When strict=False (default), also tries:
|
||||||
3. Name starts with identifier (case-insensitive)
|
3. Name starts with identifier (case-insensitive)
|
||||||
4. Name contains identifier as substring (case-insensitive)
|
4. Name contains identifier as substring (case-insensitive)
|
||||||
|
|
||||||
Note: Short identifiers (e.g. "db") may match unintended containers
|
When strict=True, only exact matches (1 & 2) are used.
|
||||||
via substring. Use more specific names or IDs for precision.
|
Use strict=True for mutations to prevent targeting the wrong container.
|
||||||
"""
|
"""
|
||||||
if not containers:
|
if not containers:
|
||||||
return None
|
return None
|
||||||
@@ -168,20 +176,24 @@ def find_container_by_identifier(
|
|||||||
if identifier in c.get("names", []):
|
if identifier in c.get("names", []):
|
||||||
return c
|
return c
|
||||||
|
|
||||||
|
# Strict mode: no fuzzy matching allowed
|
||||||
|
if strict:
|
||||||
|
return None
|
||||||
|
|
||||||
id_lower = identifier.lower()
|
id_lower = identifier.lower()
|
||||||
|
|
||||||
# Priority 3: prefix match (more precise than substring)
|
# Priority 3: prefix match (more precise than substring)
|
||||||
for c in containers:
|
for c in containers:
|
||||||
for name in c.get("names", []):
|
for name in c.get("names", []):
|
||||||
if name.lower().startswith(id_lower):
|
if name.lower().startswith(id_lower):
|
||||||
logger.info(f"Prefix match: '{identifier}' -> '{name}'")
|
logger.debug(f"Prefix match: '{identifier}' -> '{name}'")
|
||||||
return c
|
return c
|
||||||
|
|
||||||
# Priority 4: substring match (least precise)
|
# Priority 4: substring match (least precise)
|
||||||
for c in containers:
|
for c in containers:
|
||||||
for name in c.get("names", []):
|
for name in c.get("names", []):
|
||||||
if id_lower in name.lower():
|
if id_lower in name.lower():
|
||||||
logger.info(f"Substring match: '{identifier}' -> '{name}'")
|
logger.debug(f"Substring match: '{identifier}' -> '{name}'")
|
||||||
return c
|
return c
|
||||||
|
|
||||||
return None
|
return None
|
||||||
@@ -195,27 +207,66 @@ def get_available_container_names(containers: list[dict[str, Any]]) -> list[str]
|
|||||||
return names
|
return names
|
||||||
|
|
||||||
|
|
||||||
async def _resolve_container_id(container_id: str) -> str:
|
async def _resolve_container_id(container_id: str, *, strict: bool = False) -> str:
|
||||||
"""Resolve a container name/identifier to its actual PrefixedID."""
|
"""Resolve a container name/identifier to its actual PrefixedID.
|
||||||
|
|
||||||
|
Optimization: if the identifier is a full 64-char hex ID (with optional
|
||||||
|
:suffix), skip the container list fetch entirely and use it directly.
|
||||||
|
If it's a short hex prefix (12-63 chars), fetch the list and match by
|
||||||
|
ID prefix. Only fetch the container list for name-based lookups.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
container_id: Container name or ID to resolve
|
||||||
|
strict: When True, only exact name/ID matches are allowed (no fuzzy).
|
||||||
|
Use for mutations to prevent targeting the wrong container.
|
||||||
|
"""
|
||||||
|
# Full PrefixedID: skip the list fetch entirely
|
||||||
if _DOCKER_ID_PATTERN.match(container_id):
|
if _DOCKER_ID_PATTERN.match(container_id):
|
||||||
return container_id
|
return container_id
|
||||||
|
|
||||||
logger.info(f"Resolving container identifier '{container_id}'")
|
logger.info(f"Resolving container identifier '{container_id}' (strict={strict})")
|
||||||
list_query = """
|
list_query = """
|
||||||
query ResolveContainerID {
|
query ResolveContainerID {
|
||||||
docker { containers(skipCache: true) { id names } }
|
docker { containers(skipCache: true) { id names } }
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
data = await make_graphql_request(list_query)
|
data = await make_graphql_request(list_query)
|
||||||
containers = _safe_get(data, "docker", "containers", default=[])
|
containers = safe_get(data, "docker", "containers", default=[])
|
||||||
resolved = find_container_by_identifier(container_id, containers)
|
|
||||||
|
# Short hex prefix: match by ID prefix before trying name matching
|
||||||
|
if _DOCKER_SHORT_ID_PATTERN.match(container_id):
|
||||||
|
id_lower = container_id.lower()
|
||||||
|
matches: list[dict[str, Any]] = []
|
||||||
|
for c in containers:
|
||||||
|
cid = (c.get("id") or "").lower()
|
||||||
|
if cid.startswith(id_lower) or cid.split(":")[0].startswith(id_lower):
|
||||||
|
matches.append(c)
|
||||||
|
if len(matches) == 1:
|
||||||
|
actual_id = str(matches[0].get("id", ""))
|
||||||
|
logger.info(f"Resolved short ID '{container_id}' -> '{actual_id}'")
|
||||||
|
return actual_id
|
||||||
|
if len(matches) > 1:
|
||||||
|
candidate_ids = [str(c.get("id", "")) for c in matches[:5]]
|
||||||
|
raise ToolError(
|
||||||
|
f"Short container ID prefix '{container_id}' is ambiguous. "
|
||||||
|
f"Matches: {', '.join(candidate_ids)}. Use a longer ID or exact name."
|
||||||
|
)
|
||||||
|
|
||||||
|
resolved = find_container_by_identifier(container_id, containers, strict=strict)
|
||||||
if resolved:
|
if resolved:
|
||||||
actual_id = str(resolved.get("id", ""))
|
actual_id = str(resolved.get("id", ""))
|
||||||
logger.info(f"Resolved '{container_id}' -> '{actual_id}'")
|
logger.info(f"Resolved '{container_id}' -> '{actual_id}'")
|
||||||
return actual_id
|
return actual_id
|
||||||
|
|
||||||
available = get_available_container_names(containers)
|
available = get_available_container_names(containers)
|
||||||
msg = f"Container '{container_id}' not found."
|
if strict:
|
||||||
|
msg = (
|
||||||
|
f"Container '{container_id}' not found by exact match. "
|
||||||
|
f"Mutations require an exact container name or full ID — "
|
||||||
|
f"fuzzy/substring matching is not allowed for safety."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
msg = f"Container '{container_id}' not found."
|
||||||
if available:
|
if available:
|
||||||
msg += f" Available: {', '.join(available[:10])}"
|
msg += f" Available: {', '.join(available[:10])}"
|
||||||
raise ToolError(msg)
|
raise ToolError(msg)
|
||||||
@@ -264,56 +315,58 @@ def register_docker_tool(mcp: FastMCP) -> None:
|
|||||||
if action == "network_details" and not network_id:
|
if action == "network_details" and not network_id:
|
||||||
raise ToolError("network_id is required for 'network_details' action")
|
raise ToolError("network_id is required for 'network_details' action")
|
||||||
|
|
||||||
try:
|
if action == "logs" and (tail_lines < 1 or tail_lines > _MAX_TAIL_LINES):
|
||||||
|
raise ToolError(f"tail_lines must be between 1 and {_MAX_TAIL_LINES}, got {tail_lines}")
|
||||||
|
|
||||||
|
with tool_error_handler("docker", action, logger):
|
||||||
logger.info(f"Executing unraid_docker action={action}")
|
logger.info(f"Executing unraid_docker action={action}")
|
||||||
|
|
||||||
# --- Read-only queries ---
|
# --- Read-only queries ---
|
||||||
if action == "list":
|
if action == "list":
|
||||||
data = await make_graphql_request(QUERIES["list"])
|
data = await make_graphql_request(QUERIES["list"])
|
||||||
containers = _safe_get(data, "docker", "containers", default=[])
|
containers = safe_get(data, "docker", "containers", default=[])
|
||||||
return {"containers": list(containers) if isinstance(containers, list) else []}
|
return {"containers": containers}
|
||||||
|
|
||||||
if action == "details":
|
if action == "details":
|
||||||
|
# Resolve name -> ID first (skips list fetch if already an ID)
|
||||||
|
actual_id = await _resolve_container_id(container_id or "")
|
||||||
data = await make_graphql_request(QUERIES["details"])
|
data = await make_graphql_request(QUERIES["details"])
|
||||||
containers = _safe_get(data, "docker", "containers", default=[])
|
containers = safe_get(data, "docker", "containers", default=[])
|
||||||
container = find_container_by_identifier(container_id or "", containers)
|
# Match by resolved ID (exact match, no second list fetch needed)
|
||||||
if container:
|
for c in containers:
|
||||||
return container
|
if c.get("id") == actual_id:
|
||||||
available = get_available_container_names(containers)
|
return c
|
||||||
msg = f"Container '{container_id}' not found."
|
raise ToolError(f"Container '{container_id}' not found in details response.")
|
||||||
if available:
|
|
||||||
msg += f" Available: {', '.join(available[:10])}"
|
|
||||||
raise ToolError(msg)
|
|
||||||
|
|
||||||
if action == "logs":
|
if action == "logs":
|
||||||
actual_id = await _resolve_container_id(container_id or "")
|
actual_id = await _resolve_container_id(container_id or "")
|
||||||
data = await make_graphql_request(
|
data = await make_graphql_request(
|
||||||
QUERIES["logs"], {"id": actual_id, "tail": tail_lines}
|
QUERIES["logs"], {"id": actual_id, "tail": tail_lines}
|
||||||
)
|
)
|
||||||
return {"logs": _safe_get(data, "docker", "logs")}
|
return {"logs": safe_get(data, "docker", "logs")}
|
||||||
|
|
||||||
if action == "networks":
|
if action == "networks":
|
||||||
data = await make_graphql_request(QUERIES["networks"])
|
data = await make_graphql_request(QUERIES["networks"])
|
||||||
networks = data.get("dockerNetworks", [])
|
networks = safe_get(data, "dockerNetworks", default=[])
|
||||||
return {"networks": list(networks) if isinstance(networks, list) else []}
|
return {"networks": networks}
|
||||||
|
|
||||||
if action == "network_details":
|
if action == "network_details":
|
||||||
data = await make_graphql_request(QUERIES["network_details"], {"id": network_id})
|
data = await make_graphql_request(QUERIES["network_details"], {"id": network_id})
|
||||||
return dict(data.get("dockerNetwork") or {})
|
return dict(safe_get(data, "dockerNetwork", default={}) or {})
|
||||||
|
|
||||||
if action == "port_conflicts":
|
if action == "port_conflicts":
|
||||||
data = await make_graphql_request(QUERIES["port_conflicts"])
|
data = await make_graphql_request(QUERIES["port_conflicts"])
|
||||||
conflicts = _safe_get(data, "docker", "portConflicts", default=[])
|
conflicts = safe_get(data, "docker", "portConflicts", default=[])
|
||||||
return {"port_conflicts": list(conflicts) if isinstance(conflicts, list) else []}
|
return {"port_conflicts": conflicts}
|
||||||
|
|
||||||
if action == "check_updates":
|
if action == "check_updates":
|
||||||
data = await make_graphql_request(QUERIES["check_updates"])
|
data = await make_graphql_request(QUERIES["check_updates"])
|
||||||
statuses = _safe_get(data, "docker", "containerUpdateStatuses", default=[])
|
statuses = safe_get(data, "docker", "containerUpdateStatuses", default=[])
|
||||||
return {"update_statuses": list(statuses) if isinstance(statuses, list) else []}
|
return {"update_statuses": statuses}
|
||||||
|
|
||||||
# --- Mutations ---
|
# --- Mutations (strict matching: no fuzzy/substring) ---
|
||||||
if action == "restart":
|
if action == "restart":
|
||||||
actual_id = await _resolve_container_id(container_id or "")
|
actual_id = await _resolve_container_id(container_id or "", strict=True)
|
||||||
# Stop (idempotent: treat "already stopped" as success)
|
# Stop (idempotent: treat "already stopped" as success)
|
||||||
stop_data = await make_graphql_request(
|
stop_data = await make_graphql_request(
|
||||||
MUTATIONS["stop"],
|
MUTATIONS["stop"],
|
||||||
@@ -330,7 +383,7 @@ def register_docker_tool(mcp: FastMCP) -> None:
|
|||||||
if start_data.get("idempotent_success"):
|
if start_data.get("idempotent_success"):
|
||||||
result = {}
|
result = {}
|
||||||
else:
|
else:
|
||||||
result = _safe_get(start_data, "docker", "start", default={})
|
result = safe_get(start_data, "docker", "start", default={})
|
||||||
response: dict[str, Any] = {
|
response: dict[str, Any] = {
|
||||||
"success": True,
|
"success": True,
|
||||||
"action": "restart",
|
"action": "restart",
|
||||||
@@ -342,12 +395,12 @@ def register_docker_tool(mcp: FastMCP) -> None:
|
|||||||
|
|
||||||
if action == "update_all":
|
if action == "update_all":
|
||||||
data = await make_graphql_request(MUTATIONS["update_all"])
|
data = await make_graphql_request(MUTATIONS["update_all"])
|
||||||
results = _safe_get(data, "docker", "updateAllContainers", default=[])
|
results = safe_get(data, "docker", "updateAllContainers", default=[])
|
||||||
return {"success": True, "action": "update_all", "containers": results}
|
return {"success": True, "action": "update_all", "containers": results}
|
||||||
|
|
||||||
# Single-container mutations
|
# Single-container mutations
|
||||||
if action in MUTATIONS:
|
if action in MUTATIONS:
|
||||||
actual_id = await _resolve_container_id(container_id or "")
|
actual_id = await _resolve_container_id(container_id or "", strict=True)
|
||||||
op_context: dict[str, str] | None = (
|
op_context: dict[str, str] | None = (
|
||||||
{"operation": action} if action in ("start", "stop") else None
|
{"operation": action} if action in ("start", "stop") else None
|
||||||
)
|
)
|
||||||
@@ -382,10 +435,4 @@ def register_docker_tool(mcp: FastMCP) -> None:
|
|||||||
|
|
||||||
raise ToolError(f"Unhandled action '{action}' — this is a bug")
|
raise ToolError(f"Unhandled action '{action}' — this is a bug")
|
||||||
|
|
||||||
except ToolError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in unraid_docker action={action}: {e}", exc_info=True)
|
|
||||||
raise ToolError(f"Failed to execute docker/{action}: {e!s}") from e
|
|
||||||
|
|
||||||
logger.info("Docker tool registered successfully")
|
logger.info("Docker tool registered successfully")
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ connection testing, and subscription diagnostics.
|
|||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import time
|
import time
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, get_args
|
||||||
|
|
||||||
from fastmcp import FastMCP
|
from fastmcp import FastMCP
|
||||||
|
|
||||||
@@ -19,11 +19,22 @@ from ..config.settings import (
|
|||||||
VERSION,
|
VERSION,
|
||||||
)
|
)
|
||||||
from ..core.client import make_graphql_request
|
from ..core.client import make_graphql_request
|
||||||
from ..core.exceptions import ToolError
|
from ..core.exceptions import ToolError, tool_error_handler
|
||||||
|
from ..core.utils import safe_display_url
|
||||||
|
|
||||||
|
|
||||||
|
ALL_ACTIONS = {"check", "test_connection", "diagnose"}
|
||||||
|
|
||||||
HEALTH_ACTIONS = Literal["check", "test_connection", "diagnose"]
|
HEALTH_ACTIONS = Literal["check", "test_connection", "diagnose"]
|
||||||
|
|
||||||
|
if set(get_args(HEALTH_ACTIONS)) != ALL_ACTIONS:
|
||||||
|
_missing = ALL_ACTIONS - set(get_args(HEALTH_ACTIONS))
|
||||||
|
_extra = set(get_args(HEALTH_ACTIONS)) - ALL_ACTIONS
|
||||||
|
raise RuntimeError(
|
||||||
|
"HEALTH_ACTIONS and ALL_ACTIONS are out of sync. "
|
||||||
|
f"Missing in HEALTH_ACTIONS: {_missing}; extra in HEALTH_ACTIONS: {_extra}"
|
||||||
|
)
|
||||||
|
|
||||||
# Severity ordering: only upgrade, never downgrade
|
# Severity ordering: only upgrade, never downgrade
|
||||||
_SEVERITY = {"healthy": 0, "warning": 1, "degraded": 2, "unhealthy": 3}
|
_SEVERITY = {"healthy": 0, "warning": 1, "degraded": 2, "unhealthy": 3}
|
||||||
|
|
||||||
@@ -53,12 +64,10 @@ def register_health_tool(mcp: FastMCP) -> None:
|
|||||||
test_connection - Quick connectivity test (just checks { online })
|
test_connection - Quick connectivity test (just checks { online })
|
||||||
diagnose - Subscription system diagnostics
|
diagnose - Subscription system diagnostics
|
||||||
"""
|
"""
|
||||||
if action not in ("check", "test_connection", "diagnose"):
|
if action not in ALL_ACTIONS:
|
||||||
raise ToolError(
|
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(ALL_ACTIONS)}")
|
||||||
f"Invalid action '{action}'. Must be one of: check, test_connection, diagnose"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
with tool_error_handler("health", action, logger):
|
||||||
logger.info(f"Executing unraid_health action={action}")
|
logger.info(f"Executing unraid_health action={action}")
|
||||||
|
|
||||||
if action == "test_connection":
|
if action == "test_connection":
|
||||||
@@ -79,12 +88,6 @@ def register_health_tool(mcp: FastMCP) -> None:
|
|||||||
|
|
||||||
raise ToolError(f"Unhandled action '{action}' — this is a bug")
|
raise ToolError(f"Unhandled action '{action}' — this is a bug")
|
||||||
|
|
||||||
except ToolError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in unraid_health action={action}: {e}", exc_info=True)
|
|
||||||
raise ToolError(f"Failed to execute health/{action}: {e!s}") from e
|
|
||||||
|
|
||||||
logger.info("Health tool registered successfully")
|
logger.info("Health tool registered successfully")
|
||||||
|
|
||||||
|
|
||||||
@@ -111,7 +114,7 @@ async def _comprehensive_check() -> dict[str, Any]:
|
|||||||
overview { unread { alert warning total } }
|
overview { unread { alert warning total } }
|
||||||
}
|
}
|
||||||
docker {
|
docker {
|
||||||
containers(skipCache: true) { id state status }
|
containers { id state status }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
@@ -135,7 +138,7 @@ async def _comprehensive_check() -> dict[str, Any]:
|
|||||||
if info:
|
if info:
|
||||||
health_info["unraid_system"] = {
|
health_info["unraid_system"] = {
|
||||||
"status": "connected",
|
"status": "connected",
|
||||||
"url": UNRAID_API_URL,
|
"url": safe_display_url(UNRAID_API_URL),
|
||||||
"machine_id": info.get("machineId"),
|
"machine_id": info.get("machineId"),
|
||||||
"version": info.get("versions", {}).get("unraid"),
|
"version": info.get("versions", {}).get("unraid"),
|
||||||
"uptime": info.get("os", {}).get("uptime"),
|
"uptime": info.get("os", {}).get("uptime"),
|
||||||
@@ -206,7 +209,7 @@ async def _comprehensive_check() -> dict[str, Any]:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Intentionally broad: health checks must always return a result,
|
# Intentionally broad: health checks must always return a result,
|
||||||
# even on unexpected failures, so callers never get an unhandled exception.
|
# even on unexpected failures, so callers never get an unhandled exception.
|
||||||
logger.error(f"Health check failed: {e}")
|
logger.error(f"Health check failed: {e}", exc_info=True)
|
||||||
return {
|
return {
|
||||||
"status": "unhealthy",
|
"status": "unhealthy",
|
||||||
"timestamp": datetime.datetime.now(datetime.UTC).isoformat(),
|
"timestamp": datetime.datetime.now(datetime.UTC).isoformat(),
|
||||||
@@ -215,6 +218,42 @@ async def _comprehensive_check() -> dict[str, Any]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _analyze_subscription_status(
|
||||||
|
status: dict[str, Any],
|
||||||
|
) -> tuple[int, list[dict[str, Any]]]:
|
||||||
|
"""Analyze subscription status dict, returning error count and connection issues.
|
||||||
|
|
||||||
|
This is the canonical implementation of subscription status analysis.
|
||||||
|
TODO: subscriptions/diagnostics.py has a similar status-analysis pattern
|
||||||
|
in diagnose_subscriptions(). That module could import and call this helper
|
||||||
|
directly to avoid divergence. See Code-H05.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
status: Dict of subscription name -> status info from get_subscription_status().
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (error_count, connection_issues_list).
|
||||||
|
"""
|
||||||
|
error_count = 0
|
||||||
|
connection_issues: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
for sub_name, sub_status in status.items():
|
||||||
|
runtime = sub_status.get("runtime", {})
|
||||||
|
conn_state = runtime.get("connection_state", "unknown")
|
||||||
|
if conn_state in ("error", "auth_failed", "timeout", "max_retries_exceeded"):
|
||||||
|
error_count += 1
|
||||||
|
if runtime.get("last_error"):
|
||||||
|
connection_issues.append(
|
||||||
|
{
|
||||||
|
"subscription": sub_name,
|
||||||
|
"state": conn_state,
|
||||||
|
"error": runtime["last_error"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return error_count, connection_issues
|
||||||
|
|
||||||
|
|
||||||
async def _diagnose_subscriptions() -> dict[str, Any]:
|
async def _diagnose_subscriptions() -> dict[str, Any]:
|
||||||
"""Import and run subscription diagnostics."""
|
"""Import and run subscription diagnostics."""
|
||||||
try:
|
try:
|
||||||
@@ -223,13 +262,10 @@ async def _diagnose_subscriptions() -> dict[str, Any]:
|
|||||||
|
|
||||||
await ensure_subscriptions_started()
|
await ensure_subscriptions_started()
|
||||||
|
|
||||||
status = subscription_manager.get_subscription_status()
|
status = await subscription_manager.get_subscription_status()
|
||||||
# This list is intentionally placed into the summary dict below and then
|
error_count, connection_issues = _analyze_subscription_status(status)
|
||||||
# appended to in the loop — the mutable alias ensures both references
|
|
||||||
# reflect the same data without a second pass.
|
|
||||||
connection_issues: list[dict[str, Any]] = []
|
|
||||||
|
|
||||||
diagnostic_info: dict[str, Any] = {
|
return {
|
||||||
"timestamp": datetime.datetime.now(datetime.UTC).isoformat(),
|
"timestamp": datetime.datetime.now(datetime.UTC).isoformat(),
|
||||||
"environment": {
|
"environment": {
|
||||||
"auto_start_enabled": subscription_manager.auto_start_enabled,
|
"auto_start_enabled": subscription_manager.auto_start_enabled,
|
||||||
@@ -241,31 +277,12 @@ async def _diagnose_subscriptions() -> dict[str, Any]:
|
|||||||
"total_configured": len(subscription_manager.subscription_configs),
|
"total_configured": len(subscription_manager.subscription_configs),
|
||||||
"active_count": len(subscription_manager.active_subscriptions),
|
"active_count": len(subscription_manager.active_subscriptions),
|
||||||
"with_data": len(subscription_manager.resource_data),
|
"with_data": len(subscription_manager.resource_data),
|
||||||
"in_error_state": 0,
|
"in_error_state": error_count,
|
||||||
"connection_issues": connection_issues,
|
"connection_issues": connection_issues,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for sub_name, sub_status in status.items():
|
except ImportError as e:
|
||||||
runtime = sub_status.get("runtime", {})
|
raise ToolError("Subscription modules not available") from e
|
||||||
conn_state = runtime.get("connection_state", "unknown")
|
|
||||||
if conn_state in ("error", "auth_failed", "timeout", "max_retries_exceeded"):
|
|
||||||
diagnostic_info["summary"]["in_error_state"] += 1
|
|
||||||
if runtime.get("last_error"):
|
|
||||||
connection_issues.append(
|
|
||||||
{
|
|
||||||
"subscription": sub_name,
|
|
||||||
"state": conn_state,
|
|
||||||
"error": runtime["last_error"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return diagnostic_info
|
|
||||||
|
|
||||||
except ImportError:
|
|
||||||
return {
|
|
||||||
"error": "Subscription modules not available",
|
|
||||||
"timestamp": datetime.datetime.now(datetime.UTC).isoformat(),
|
|
||||||
}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ToolError(f"Failed to generate diagnostics: {e!s}") from e
|
raise ToolError(f"Failed to generate diagnostics: {e!s}") from e
|
||||||
|
|||||||
@@ -4,13 +4,14 @@ Provides the `unraid_info` tool with 19 read-only actions for retrieving
|
|||||||
system information, array status, network config, and server metadata.
|
system information, array status, network config, and server metadata.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, get_args
|
||||||
|
|
||||||
from fastmcp import FastMCP
|
from fastmcp import FastMCP
|
||||||
|
|
||||||
from ..config.logging import logger
|
from ..config.logging import logger
|
||||||
from ..core.client import make_graphql_request
|
from ..core.client import make_graphql_request
|
||||||
from ..core.exceptions import ToolError
|
from ..core.exceptions import ToolError, tool_error_handler
|
||||||
|
from ..core.utils import format_kb
|
||||||
|
|
||||||
|
|
||||||
# Pre-built queries keyed by action name
|
# Pre-built queries keyed by action name
|
||||||
@@ -19,7 +20,7 @@ QUERIES: dict[str, str] = {
|
|||||||
query GetSystemInfo {
|
query GetSystemInfo {
|
||||||
info {
|
info {
|
||||||
os { platform distro release codename kernel arch hostname codepage logofile serial build uptime }
|
os { platform distro release codename kernel arch hostname codepage logofile serial build uptime }
|
||||||
cpu { manufacturer brand vendor family model stepping revision voltage speed speedmin speedmax threads cores processors socket cache flags }
|
cpu { manufacturer brand vendor family model stepping revision voltage speed speedmin speedmax threads cores processors socket cache }
|
||||||
memory {
|
memory {
|
||||||
layout { bank type clockSpeed formFactor manufacturer partNum serialNum }
|
layout { bank type clockSpeed formFactor manufacturer partNum serialNum }
|
||||||
}
|
}
|
||||||
@@ -81,7 +82,6 @@ QUERIES: dict[str, str] = {
|
|||||||
shareAvahiEnabled safeMode startMode configValid configError joinStatus
|
shareAvahiEnabled safeMode startMode configValid configError joinStatus
|
||||||
deviceCount flashGuid flashProduct flashVendor mdState mdVersion
|
deviceCount flashGuid flashProduct flashVendor mdState mdVersion
|
||||||
shareCount shareSmbCount shareNfsCount shareAfpCount shareMoverActive
|
shareCount shareSmbCount shareNfsCount shareAfpCount shareMoverActive
|
||||||
csrfToken
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
""",
|
""",
|
||||||
@@ -156,6 +156,8 @@ QUERIES: dict[str, str] = {
|
|||||||
""",
|
""",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ALL_ACTIONS = set(QUERIES)
|
||||||
|
|
||||||
INFO_ACTIONS = Literal[
|
INFO_ACTIONS = Literal[
|
||||||
"overview",
|
"overview",
|
||||||
"array",
|
"array",
|
||||||
@@ -178,9 +180,13 @@ INFO_ACTIONS = Literal[
|
|||||||
"ups_config",
|
"ups_config",
|
||||||
]
|
]
|
||||||
|
|
||||||
assert set(QUERIES.keys()) == set(INFO_ACTIONS.__args__), (
|
if set(get_args(INFO_ACTIONS)) != ALL_ACTIONS:
|
||||||
"QUERIES keys and INFO_ACTIONS are out of sync"
|
_missing = ALL_ACTIONS - set(get_args(INFO_ACTIONS))
|
||||||
)
|
_extra = set(get_args(INFO_ACTIONS)) - ALL_ACTIONS
|
||||||
|
raise RuntimeError(
|
||||||
|
f"QUERIES keys and INFO_ACTIONS are out of sync. "
|
||||||
|
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _process_system_info(raw_info: dict[str, Any]) -> dict[str, Any]:
|
def _process_system_info(raw_info: dict[str, Any]) -> dict[str, Any]:
|
||||||
@@ -189,17 +195,17 @@ def _process_system_info(raw_info: dict[str, Any]) -> dict[str, Any]:
|
|||||||
if raw_info.get("os"):
|
if raw_info.get("os"):
|
||||||
os_info = raw_info["os"]
|
os_info = raw_info["os"]
|
||||||
summary["os"] = (
|
summary["os"] = (
|
||||||
f"{os_info.get('distro', '')} {os_info.get('release', '')} "
|
f"{os_info.get('distro') or 'unknown'} {os_info.get('release') or 'unknown'} "
|
||||||
f"({os_info.get('platform', '')}, {os_info.get('arch', '')})"
|
f"({os_info.get('platform') or 'unknown'}, {os_info.get('arch') or 'unknown'})"
|
||||||
)
|
)
|
||||||
summary["hostname"] = os_info.get("hostname")
|
summary["hostname"] = os_info.get("hostname") or "unknown"
|
||||||
summary["uptime"] = os_info.get("uptime")
|
summary["uptime"] = os_info.get("uptime")
|
||||||
|
|
||||||
if raw_info.get("cpu"):
|
if raw_info.get("cpu"):
|
||||||
cpu = raw_info["cpu"]
|
cpu = raw_info["cpu"]
|
||||||
summary["cpu"] = (
|
summary["cpu"] = (
|
||||||
f"{cpu.get('manufacturer', '')} {cpu.get('brand', '')} "
|
f"{cpu.get('manufacturer') or 'unknown'} {cpu.get('brand') or 'unknown'} "
|
||||||
f"({cpu.get('cores', '?')} cores, {cpu.get('threads', '?')} threads)"
|
f"({cpu.get('cores') or '?'} cores, {cpu.get('threads') or '?'} threads)"
|
||||||
)
|
)
|
||||||
|
|
||||||
if raw_info.get("memory") and raw_info["memory"].get("layout"):
|
if raw_info.get("memory") and raw_info["memory"].get("layout"):
|
||||||
@@ -207,10 +213,10 @@ def _process_system_info(raw_info: dict[str, Any]) -> dict[str, Any]:
|
|||||||
summary["memory_layout_details"] = []
|
summary["memory_layout_details"] = []
|
||||||
for stick in mem_layout:
|
for stick in mem_layout:
|
||||||
summary["memory_layout_details"].append(
|
summary["memory_layout_details"].append(
|
||||||
f"Bank {stick.get('bank', '?')}: Type {stick.get('type', '?')}, "
|
f"Bank {stick.get('bank') or '?'}: Type {stick.get('type') or '?'}, "
|
||||||
f"Speed {stick.get('clockSpeed', '?')}MHz, "
|
f"Speed {stick.get('clockSpeed') or '?'}MHz, "
|
||||||
f"Manufacturer: {stick.get('manufacturer', '?')}, "
|
f"Manufacturer: {stick.get('manufacturer') or '?'}, "
|
||||||
f"Part: {stick.get('partNum', '?')}"
|
f"Part: {stick.get('partNum') or '?'}"
|
||||||
)
|
)
|
||||||
summary["memory_summary"] = (
|
summary["memory_summary"] = (
|
||||||
"Stick layout details retrieved. Overall total/used/free memory stats "
|
"Stick layout details retrieved. Overall total/used/free memory stats "
|
||||||
@@ -255,31 +261,14 @@ def _analyze_disk_health(disks: list[dict[str, Any]]) -> dict[str, int]:
|
|||||||
return counts
|
return counts
|
||||||
|
|
||||||
|
|
||||||
def _format_kb(k: Any) -> str:
|
|
||||||
"""Format kilobyte values into human-readable sizes."""
|
|
||||||
if k is None:
|
|
||||||
return "N/A"
|
|
||||||
try:
|
|
||||||
k = int(k)
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
return "N/A"
|
|
||||||
if k >= 1024 * 1024 * 1024:
|
|
||||||
return f"{k / (1024 * 1024 * 1024):.2f} TB"
|
|
||||||
if k >= 1024 * 1024:
|
|
||||||
return f"{k / (1024 * 1024):.2f} GB"
|
|
||||||
if k >= 1024:
|
|
||||||
return f"{k / 1024:.2f} MB"
|
|
||||||
return f"{k} KB"
|
|
||||||
|
|
||||||
|
|
||||||
def _process_array_status(raw: dict[str, Any]) -> dict[str, Any]:
|
def _process_array_status(raw: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Process raw array data into summary + details."""
|
"""Process raw array data into summary + details."""
|
||||||
summary: dict[str, Any] = {"state": raw.get("state")}
|
summary: dict[str, Any] = {"state": raw.get("state")}
|
||||||
if raw.get("capacity") and raw["capacity"].get("kilobytes"):
|
if raw.get("capacity") and raw["capacity"].get("kilobytes"):
|
||||||
kb = raw["capacity"]["kilobytes"]
|
kb = raw["capacity"]["kilobytes"]
|
||||||
summary["capacity_total"] = _format_kb(kb.get("total"))
|
summary["capacity_total"] = format_kb(kb.get("total"))
|
||||||
summary["capacity_used"] = _format_kb(kb.get("used"))
|
summary["capacity_used"] = format_kb(kb.get("used"))
|
||||||
summary["capacity_free"] = _format_kb(kb.get("free"))
|
summary["capacity_free"] = format_kb(kb.get("free"))
|
||||||
|
|
||||||
summary["num_parity_disks"] = len(raw.get("parities", []))
|
summary["num_parity_disks"] = len(raw.get("parities", []))
|
||||||
summary["num_data_disks"] = len(raw.get("disks", []))
|
summary["num_data_disks"] = len(raw.get("disks", []))
|
||||||
@@ -345,8 +334,8 @@ def register_info_tool(mcp: FastMCP) -> None:
|
|||||||
ups_device - Single UPS device (requires device_id)
|
ups_device - Single UPS device (requires device_id)
|
||||||
ups_config - UPS configuration
|
ups_config - UPS configuration
|
||||||
"""
|
"""
|
||||||
if action not in QUERIES:
|
if action not in ALL_ACTIONS:
|
||||||
raise ToolError(f"Invalid action '{action}'. Must be one of: {list(QUERIES.keys())}")
|
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(ALL_ACTIONS)}")
|
||||||
|
|
||||||
if action == "ups_device" and not device_id:
|
if action == "ups_device" and not device_id:
|
||||||
raise ToolError("device_id is required for ups_device action")
|
raise ToolError("device_id is required for ups_device action")
|
||||||
@@ -377,7 +366,7 @@ def register_info_tool(mcp: FastMCP) -> None:
|
|||||||
"ups_devices": ("upsDevices", "ups_devices"),
|
"ups_devices": ("upsDevices", "ups_devices"),
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
with tool_error_handler("info", action, logger):
|
||||||
logger.info(f"Executing unraid_info action={action}")
|
logger.info(f"Executing unraid_info action={action}")
|
||||||
data = await make_graphql_request(query, variables)
|
data = await make_graphql_request(query, variables)
|
||||||
|
|
||||||
@@ -426,14 +415,9 @@ def register_info_tool(mcp: FastMCP) -> None:
|
|||||||
if action in list_actions:
|
if action in list_actions:
|
||||||
response_key, output_key = list_actions[action]
|
response_key, output_key = list_actions[action]
|
||||||
items = data.get(response_key) or []
|
items = data.get(response_key) or []
|
||||||
return {output_key: list(items) if isinstance(items, list) else []}
|
normalized_items = list(items) if isinstance(items, list) else []
|
||||||
|
return {output_key: normalized_items}
|
||||||
|
|
||||||
raise ToolError(f"Unhandled action '{action}' — this is a bug")
|
raise ToolError(f"Unhandled action '{action}' — this is a bug")
|
||||||
|
|
||||||
except ToolError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in unraid_info action={action}: {e}", exc_info=True)
|
|
||||||
raise ToolError(f"Failed to execute info/{action}: {e!s}") from e
|
|
||||||
|
|
||||||
logger.info("Info tool registered successfully")
|
logger.info("Info tool registered successfully")
|
||||||
|
|||||||
@@ -4,13 +4,13 @@ Provides the `unraid_keys` tool with 5 actions for listing, viewing,
|
|||||||
creating, updating, and deleting API keys.
|
creating, updating, and deleting API keys.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, get_args
|
||||||
|
|
||||||
from fastmcp import FastMCP
|
from fastmcp import FastMCP
|
||||||
|
|
||||||
from ..config.logging import logger
|
from ..config.logging import logger
|
||||||
from ..core.client import make_graphql_request
|
from ..core.client import make_graphql_request
|
||||||
from ..core.exceptions import ToolError
|
from ..core.exceptions import ToolError, tool_error_handler
|
||||||
|
|
||||||
|
|
||||||
QUERIES: dict[str, str] = {
|
QUERIES: dict[str, str] = {
|
||||||
@@ -45,6 +45,7 @@ MUTATIONS: dict[str, str] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
DESTRUCTIVE_ACTIONS = {"delete"}
|
DESTRUCTIVE_ACTIONS = {"delete"}
|
||||||
|
ALL_ACTIONS = set(QUERIES) | set(MUTATIONS)
|
||||||
|
|
||||||
KEY_ACTIONS = Literal[
|
KEY_ACTIONS = Literal[
|
||||||
"list",
|
"list",
|
||||||
@@ -54,6 +55,14 @@ KEY_ACTIONS = Literal[
|
|||||||
"delete",
|
"delete",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if set(get_args(KEY_ACTIONS)) != ALL_ACTIONS:
|
||||||
|
_missing = ALL_ACTIONS - set(get_args(KEY_ACTIONS))
|
||||||
|
_extra = set(get_args(KEY_ACTIONS)) - ALL_ACTIONS
|
||||||
|
raise RuntimeError(
|
||||||
|
f"KEY_ACTIONS and ALL_ACTIONS are out of sync. "
|
||||||
|
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def register_keys_tool(mcp: FastMCP) -> None:
|
def register_keys_tool(mcp: FastMCP) -> None:
|
||||||
"""Register the unraid_keys tool with the FastMCP instance."""
|
"""Register the unraid_keys tool with the FastMCP instance."""
|
||||||
@@ -76,14 +85,13 @@ def register_keys_tool(mcp: FastMCP) -> None:
|
|||||||
update - Update an API key (requires key_id; optional name, roles)
|
update - Update an API key (requires key_id; optional name, roles)
|
||||||
delete - Delete API keys (requires key_id, confirm=True)
|
delete - Delete API keys (requires key_id, confirm=True)
|
||||||
"""
|
"""
|
||||||
all_actions = set(QUERIES) | set(MUTATIONS)
|
if action not in ALL_ACTIONS:
|
||||||
if action not in all_actions:
|
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(ALL_ACTIONS)}")
|
||||||
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(all_actions)}")
|
|
||||||
|
|
||||||
if action in DESTRUCTIVE_ACTIONS and not confirm:
|
if action in DESTRUCTIVE_ACTIONS and not confirm:
|
||||||
raise ToolError(f"Action '{action}' is destructive. Set confirm=True to proceed.")
|
raise ToolError(f"Action '{action}' is destructive. Set confirm=True to proceed.")
|
||||||
|
|
||||||
try:
|
with tool_error_handler("keys", action, logger):
|
||||||
logger.info(f"Executing unraid_keys action={action}")
|
logger.info(f"Executing unraid_keys action={action}")
|
||||||
|
|
||||||
if action == "list":
|
if action == "list":
|
||||||
@@ -141,10 +149,4 @@ def register_keys_tool(mcp: FastMCP) -> None:
|
|||||||
|
|
||||||
raise ToolError(f"Unhandled action '{action}' — this is a bug")
|
raise ToolError(f"Unhandled action '{action}' — this is a bug")
|
||||||
|
|
||||||
except ToolError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in unraid_keys action={action}: {e}", exc_info=True)
|
|
||||||
raise ToolError(f"Failed to execute keys/{action}: {e!s}") from e
|
|
||||||
|
|
||||||
logger.info("Keys tool registered successfully")
|
logger.info("Keys tool registered successfully")
|
||||||
|
|||||||
@@ -4,13 +4,13 @@ Provides the `unraid_notifications` tool with 9 actions for viewing,
|
|||||||
creating, archiving, and deleting system notifications.
|
creating, archiving, and deleting system notifications.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, get_args
|
||||||
|
|
||||||
from fastmcp import FastMCP
|
from fastmcp import FastMCP
|
||||||
|
|
||||||
from ..config.logging import logger
|
from ..config.logging import logger
|
||||||
from ..core.client import make_graphql_request
|
from ..core.client import make_graphql_request
|
||||||
from ..core.exceptions import ToolError
|
from ..core.exceptions import ToolError, tool_error_handler
|
||||||
|
|
||||||
|
|
||||||
QUERIES: dict[str, str] = {
|
QUERIES: dict[str, str] = {
|
||||||
@@ -76,6 +76,8 @@ MUTATIONS: dict[str, str] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
DESTRUCTIVE_ACTIONS = {"delete", "delete_archived"}
|
DESTRUCTIVE_ACTIONS = {"delete", "delete_archived"}
|
||||||
|
ALL_ACTIONS = set(QUERIES) | set(MUTATIONS)
|
||||||
|
_VALID_IMPORTANCE = {"ALERT", "WARNING", "NORMAL"}
|
||||||
|
|
||||||
NOTIFICATION_ACTIONS = Literal[
|
NOTIFICATION_ACTIONS = Literal[
|
||||||
"overview",
|
"overview",
|
||||||
@@ -89,6 +91,14 @@ NOTIFICATION_ACTIONS = Literal[
|
|||||||
"archive_all",
|
"archive_all",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if set(get_args(NOTIFICATION_ACTIONS)) != ALL_ACTIONS:
|
||||||
|
_missing = ALL_ACTIONS - set(get_args(NOTIFICATION_ACTIONS))
|
||||||
|
_extra = set(get_args(NOTIFICATION_ACTIONS)) - ALL_ACTIONS
|
||||||
|
raise RuntimeError(
|
||||||
|
f"NOTIFICATION_ACTIONS and ALL_ACTIONS are out of sync. "
|
||||||
|
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def register_notifications_tool(mcp: FastMCP) -> None:
|
def register_notifications_tool(mcp: FastMCP) -> None:
|
||||||
"""Register the unraid_notifications tool with the FastMCP instance."""
|
"""Register the unraid_notifications tool with the FastMCP instance."""
|
||||||
@@ -120,16 +130,13 @@ def register_notifications_tool(mcp: FastMCP) -> None:
|
|||||||
delete_archived - Delete all archived notifications (requires confirm=True)
|
delete_archived - Delete all archived notifications (requires confirm=True)
|
||||||
archive_all - Archive all notifications (optional importance filter)
|
archive_all - Archive all notifications (optional importance filter)
|
||||||
"""
|
"""
|
||||||
all_actions = {**QUERIES, **MUTATIONS}
|
if action not in ALL_ACTIONS:
|
||||||
if action not in all_actions:
|
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(ALL_ACTIONS)}")
|
||||||
raise ToolError(
|
|
||||||
f"Invalid action '{action}'. Must be one of: {list(all_actions.keys())}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if action in DESTRUCTIVE_ACTIONS and not confirm:
|
if action in DESTRUCTIVE_ACTIONS and not confirm:
|
||||||
raise ToolError(f"Action '{action}' is destructive. Set confirm=True to proceed.")
|
raise ToolError(f"Action '{action}' is destructive. Set confirm=True to proceed.")
|
||||||
|
|
||||||
try:
|
with tool_error_handler("notifications", action, logger):
|
||||||
logger.info(f"Executing unraid_notifications action={action}")
|
logger.info(f"Executing unraid_notifications action={action}")
|
||||||
|
|
||||||
if action == "overview":
|
if action == "overview":
|
||||||
@@ -147,18 +154,29 @@ def register_notifications_tool(mcp: FastMCP) -> None:
|
|||||||
filter_vars["importance"] = importance.upper()
|
filter_vars["importance"] = importance.upper()
|
||||||
data = await make_graphql_request(QUERIES["list"], {"filter": filter_vars})
|
data = await make_graphql_request(QUERIES["list"], {"filter": filter_vars})
|
||||||
notifications = data.get("notifications", {})
|
notifications = data.get("notifications", {})
|
||||||
result = notifications.get("list", [])
|
return {"notifications": notifications.get("list", [])}
|
||||||
return {"notifications": list(result) if isinstance(result, list) else []}
|
|
||||||
|
|
||||||
if action == "warnings":
|
if action == "warnings":
|
||||||
data = await make_graphql_request(QUERIES["warnings"])
|
data = await make_graphql_request(QUERIES["warnings"])
|
||||||
notifications = data.get("notifications", {})
|
notifications = data.get("notifications", {})
|
||||||
result = notifications.get("warningsAndAlerts", [])
|
return {"warnings": notifications.get("warningsAndAlerts", [])}
|
||||||
return {"warnings": list(result) if isinstance(result, list) else []}
|
|
||||||
|
|
||||||
if action == "create":
|
if action == "create":
|
||||||
if title is None or subject is None or description is None or importance is None:
|
if title is None or subject is None or description is None or importance is None:
|
||||||
raise ToolError("create requires title, subject, description, and importance")
|
raise ToolError("create requires title, subject, description, and importance")
|
||||||
|
if importance.upper() not in _VALID_IMPORTANCE:
|
||||||
|
raise ToolError(
|
||||||
|
f"importance must be one of: {', '.join(sorted(_VALID_IMPORTANCE))}. "
|
||||||
|
f"Got: '{importance}'"
|
||||||
|
)
|
||||||
|
if len(title) > 200:
|
||||||
|
raise ToolError(f"title must be at most 200 characters (got {len(title)})")
|
||||||
|
if len(subject) > 500:
|
||||||
|
raise ToolError(f"subject must be at most 500 characters (got {len(subject)})")
|
||||||
|
if len(description) > 2000:
|
||||||
|
raise ToolError(
|
||||||
|
f"description must be at most 2000 characters (got {len(description)})"
|
||||||
|
)
|
||||||
input_data = {
|
input_data = {
|
||||||
"title": title,
|
"title": title,
|
||||||
"subject": subject,
|
"subject": subject,
|
||||||
@@ -196,10 +214,4 @@ def register_notifications_tool(mcp: FastMCP) -> None:
|
|||||||
|
|
||||||
raise ToolError(f"Unhandled action '{action}' — this is a bug")
|
raise ToolError(f"Unhandled action '{action}' — this is a bug")
|
||||||
|
|
||||||
except ToolError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in unraid_notifications action={action}: {e}", exc_info=True)
|
|
||||||
raise ToolError(f"Failed to execute notifications/{action}: {e!s}") from e
|
|
||||||
|
|
||||||
logger.info("Notifications tool registered successfully")
|
logger.info("Notifications tool registered successfully")
|
||||||
|
|||||||
@@ -4,13 +4,14 @@ Provides the `unraid_rclone` tool with 4 actions for managing
|
|||||||
cloud storage remotes (S3, Google Drive, Dropbox, FTP, etc.).
|
cloud storage remotes (S3, Google Drive, Dropbox, FTP, etc.).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Literal
|
import re
|
||||||
|
from typing import Any, Literal, get_args
|
||||||
|
|
||||||
from fastmcp import FastMCP
|
from fastmcp import FastMCP
|
||||||
|
|
||||||
from ..config.logging import logger
|
from ..config.logging import logger
|
||||||
from ..core.client import make_graphql_request
|
from ..core.client import make_graphql_request
|
||||||
from ..core.exceptions import ToolError
|
from ..core.exceptions import ToolError, tool_error_handler
|
||||||
|
|
||||||
|
|
||||||
QUERIES: dict[str, str] = {
|
QUERIES: dict[str, str] = {
|
||||||
@@ -49,6 +50,59 @@ RCLONE_ACTIONS = Literal[
|
|||||||
"delete_remote",
|
"delete_remote",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if set(get_args(RCLONE_ACTIONS)) != ALL_ACTIONS:
|
||||||
|
_missing = ALL_ACTIONS - set(get_args(RCLONE_ACTIONS))
|
||||||
|
_extra = set(get_args(RCLONE_ACTIONS)) - ALL_ACTIONS
|
||||||
|
raise RuntimeError(
|
||||||
|
f"RCLONE_ACTIONS and ALL_ACTIONS are out of sync. "
|
||||||
|
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Max config entries to prevent abuse
|
||||||
|
_MAX_CONFIG_KEYS = 50
|
||||||
|
# Pattern for suspicious key names (path traversal, shell metacharacters)
|
||||||
|
_DANGEROUS_KEY_PATTERN = re.compile(r"\.\.|[/\\;|`$(){}]")
|
||||||
|
# Max length for individual config values
|
||||||
|
_MAX_VALUE_LENGTH = 4096
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_config_data(config_data: dict[str, Any]) -> dict[str, str]:
|
||||||
|
"""Validate and sanitize rclone config_data before passing to GraphQL.
|
||||||
|
|
||||||
|
Ensures all keys and values are safe strings with no injection vectors.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ToolError: If config_data contains invalid keys or values
|
||||||
|
"""
|
||||||
|
if len(config_data) > _MAX_CONFIG_KEYS:
|
||||||
|
raise ToolError(f"config_data has {len(config_data)} keys (max {_MAX_CONFIG_KEYS})")
|
||||||
|
|
||||||
|
validated: dict[str, str] = {}
|
||||||
|
for key, value in config_data.items():
|
||||||
|
if not isinstance(key, str) or not key.strip():
|
||||||
|
raise ToolError(
|
||||||
|
f"config_data keys must be non-empty strings, got: {type(key).__name__}"
|
||||||
|
)
|
||||||
|
if _DANGEROUS_KEY_PATTERN.search(key):
|
||||||
|
raise ToolError(
|
||||||
|
f"config_data key '{key}' contains disallowed characters "
|
||||||
|
f"(path traversal or shell metacharacters)"
|
||||||
|
)
|
||||||
|
if not isinstance(value, (str, int, float, bool)):
|
||||||
|
raise ToolError(
|
||||||
|
f"config_data['{key}'] must be a string, number, or boolean, "
|
||||||
|
f"got: {type(value).__name__}"
|
||||||
|
)
|
||||||
|
str_value = str(value)
|
||||||
|
if len(str_value) > _MAX_VALUE_LENGTH:
|
||||||
|
raise ToolError(
|
||||||
|
f"config_data['{key}'] value exceeds max length "
|
||||||
|
f"({len(str_value)} > {_MAX_VALUE_LENGTH})"
|
||||||
|
)
|
||||||
|
validated[key] = str_value
|
||||||
|
|
||||||
|
return validated
|
||||||
|
|
||||||
|
|
||||||
def register_rclone_tool(mcp: FastMCP) -> None:
|
def register_rclone_tool(mcp: FastMCP) -> None:
|
||||||
"""Register the unraid_rclone tool with the FastMCP instance."""
|
"""Register the unraid_rclone tool with the FastMCP instance."""
|
||||||
@@ -75,7 +129,7 @@ def register_rclone_tool(mcp: FastMCP) -> None:
|
|||||||
if action in DESTRUCTIVE_ACTIONS and not confirm:
|
if action in DESTRUCTIVE_ACTIONS and not confirm:
|
||||||
raise ToolError(f"Action '{action}' is destructive. Set confirm=True to proceed.")
|
raise ToolError(f"Action '{action}' is destructive. Set confirm=True to proceed.")
|
||||||
|
|
||||||
try:
|
with tool_error_handler("rclone", action, logger):
|
||||||
logger.info(f"Executing unraid_rclone action={action}")
|
logger.info(f"Executing unraid_rclone action={action}")
|
||||||
|
|
||||||
if action == "list_remotes":
|
if action == "list_remotes":
|
||||||
@@ -96,9 +150,10 @@ def register_rclone_tool(mcp: FastMCP) -> None:
|
|||||||
if action == "create_remote":
|
if action == "create_remote":
|
||||||
if name is None or provider_type is None or config_data is None:
|
if name is None or provider_type is None or config_data is None:
|
||||||
raise ToolError("create_remote requires name, provider_type, and config_data")
|
raise ToolError("create_remote requires name, provider_type, and config_data")
|
||||||
|
validated_config = _validate_config_data(config_data)
|
||||||
data = await make_graphql_request(
|
data = await make_graphql_request(
|
||||||
MUTATIONS["create_remote"],
|
MUTATIONS["create_remote"],
|
||||||
{"input": {"name": name, "type": provider_type, "config": config_data}},
|
{"input": {"name": name, "type": provider_type, "config": validated_config}},
|
||||||
)
|
)
|
||||||
remote = data.get("rclone", {}).get("createRCloneRemote")
|
remote = data.get("rclone", {}).get("createRCloneRemote")
|
||||||
if not remote:
|
if not remote:
|
||||||
@@ -127,10 +182,4 @@ def register_rclone_tool(mcp: FastMCP) -> None:
|
|||||||
|
|
||||||
raise ToolError(f"Unhandled action '{action}' — this is a bug")
|
raise ToolError(f"Unhandled action '{action}' — this is a bug")
|
||||||
|
|
||||||
except ToolError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in unraid_rclone action={action}: {e}", exc_info=True)
|
|
||||||
raise ToolError(f"Failed to execute rclone/{action}: {e!s}") from e
|
|
||||||
|
|
||||||
logger.info("RClone tool registered successfully")
|
logger.info("RClone tool registered successfully")
|
||||||
|
|||||||
@@ -4,17 +4,19 @@ Provides the `unraid_storage` tool with 6 actions for shares, physical disks,
|
|||||||
unassigned devices, log files, and log content retrieval.
|
unassigned devices, log files, and log content retrieval.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Literal
|
import os
|
||||||
|
from typing import Any, Literal, get_args
|
||||||
|
|
||||||
import anyio
|
|
||||||
from fastmcp import FastMCP
|
from fastmcp import FastMCP
|
||||||
|
|
||||||
from ..config.logging import logger
|
from ..config.logging import logger
|
||||||
from ..core.client import DISK_TIMEOUT, make_graphql_request
|
from ..core.client import DISK_TIMEOUT, make_graphql_request
|
||||||
from ..core.exceptions import ToolError
|
from ..core.exceptions import ToolError, tool_error_handler
|
||||||
|
from ..core.utils import format_bytes
|
||||||
|
|
||||||
|
|
||||||
_ALLOWED_LOG_PREFIXES = ("/var/log/", "/boot/logs/", "/mnt/")
|
_ALLOWED_LOG_PREFIXES = ("/var/log/", "/boot/logs/", "/mnt/")
|
||||||
|
_MAX_TAIL_LINES = 10_000
|
||||||
|
|
||||||
QUERIES: dict[str, str] = {
|
QUERIES: dict[str, str] = {
|
||||||
"shares": """
|
"shares": """
|
||||||
@@ -56,6 +58,8 @@ QUERIES: dict[str, str] = {
|
|||||||
""",
|
""",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ALL_ACTIONS = set(QUERIES)
|
||||||
|
|
||||||
STORAGE_ACTIONS = Literal[
|
STORAGE_ACTIONS = Literal[
|
||||||
"shares",
|
"shares",
|
||||||
"disks",
|
"disks",
|
||||||
@@ -65,20 +69,13 @@ STORAGE_ACTIONS = Literal[
|
|||||||
"logs",
|
"logs",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if set(get_args(STORAGE_ACTIONS)) != ALL_ACTIONS:
|
||||||
def format_bytes(bytes_value: int | None) -> str:
|
_missing = ALL_ACTIONS - set(get_args(STORAGE_ACTIONS))
|
||||||
"""Format byte values into human-readable sizes."""
|
_extra = set(get_args(STORAGE_ACTIONS)) - ALL_ACTIONS
|
||||||
if bytes_value is None:
|
raise RuntimeError(
|
||||||
return "N/A"
|
f"STORAGE_ACTIONS and ALL_ACTIONS are out of sync. "
|
||||||
try:
|
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
|
||||||
value = float(int(bytes_value))
|
)
|
||||||
except (ValueError, TypeError):
|
|
||||||
return "N/A"
|
|
||||||
for unit in ["B", "KB", "MB", "GB", "TB", "PB"]:
|
|
||||||
if value < 1024.0:
|
|
||||||
return f"{value:.2f} {unit}"
|
|
||||||
value /= 1024.0
|
|
||||||
return f"{value:.2f} EB"
|
|
||||||
|
|
||||||
|
|
||||||
def register_storage_tool(mcp: FastMCP) -> None:
|
def register_storage_tool(mcp: FastMCP) -> None:
|
||||||
@@ -101,17 +98,22 @@ def register_storage_tool(mcp: FastMCP) -> None:
|
|||||||
log_files - List available log files
|
log_files - List available log files
|
||||||
logs - Retrieve log content (requires log_path, optional tail_lines)
|
logs - Retrieve log content (requires log_path, optional tail_lines)
|
||||||
"""
|
"""
|
||||||
if action not in QUERIES:
|
if action not in ALL_ACTIONS:
|
||||||
raise ToolError(f"Invalid action '{action}'. Must be one of: {list(QUERIES.keys())}")
|
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(ALL_ACTIONS)}")
|
||||||
|
|
||||||
if action == "disk_details" and not disk_id:
|
if action == "disk_details" and not disk_id:
|
||||||
raise ToolError("disk_id is required for 'disk_details' action")
|
raise ToolError("disk_id is required for 'disk_details' action")
|
||||||
|
|
||||||
|
if action == "logs" and (tail_lines < 1 or tail_lines > _MAX_TAIL_LINES):
|
||||||
|
raise ToolError(f"tail_lines must be between 1 and {_MAX_TAIL_LINES}, got {tail_lines}")
|
||||||
|
|
||||||
if action == "logs":
|
if action == "logs":
|
||||||
if not log_path:
|
if not log_path:
|
||||||
raise ToolError("log_path is required for 'logs' action")
|
raise ToolError("log_path is required for 'logs' action")
|
||||||
# Resolve path to prevent traversal attacks (e.g. /var/log/../../etc/shadow)
|
# Resolve path synchronously to prevent traversal attacks.
|
||||||
normalized = str(await anyio.Path(log_path).resolve())
|
# Using os.path.realpath instead of anyio.Path.resolve() because the
|
||||||
|
# async variant blocks on NFS-mounted paths under /mnt/ (Perf-AI-1).
|
||||||
|
normalized = os.path.realpath(log_path) # noqa: ASYNC240
|
||||||
if not any(normalized.startswith(p) for p in _ALLOWED_LOG_PREFIXES):
|
if not any(normalized.startswith(p) for p in _ALLOWED_LOG_PREFIXES):
|
||||||
raise ToolError(
|
raise ToolError(
|
||||||
f"log_path must start with one of: {', '.join(_ALLOWED_LOG_PREFIXES)}. "
|
f"log_path must start with one of: {', '.join(_ALLOWED_LOG_PREFIXES)}. "
|
||||||
@@ -128,17 +130,15 @@ def register_storage_tool(mcp: FastMCP) -> None:
|
|||||||
elif action == "logs":
|
elif action == "logs":
|
||||||
variables = {"path": log_path, "lines": tail_lines}
|
variables = {"path": log_path, "lines": tail_lines}
|
||||||
|
|
||||||
try:
|
with tool_error_handler("storage", action, logger):
|
||||||
logger.info(f"Executing unraid_storage action={action}")
|
logger.info(f"Executing unraid_storage action={action}")
|
||||||
data = await make_graphql_request(query, variables, custom_timeout=custom_timeout)
|
data = await make_graphql_request(query, variables, custom_timeout=custom_timeout)
|
||||||
|
|
||||||
if action == "shares":
|
if action == "shares":
|
||||||
shares = data.get("shares", [])
|
return {"shares": data.get("shares", [])}
|
||||||
return {"shares": list(shares) if isinstance(shares, list) else []}
|
|
||||||
|
|
||||||
if action == "disks":
|
if action == "disks":
|
||||||
disks = data.get("disks", [])
|
return {"disks": data.get("disks", [])}
|
||||||
return {"disks": list(disks) if isinstance(disks, list) else []}
|
|
||||||
|
|
||||||
if action == "disk_details":
|
if action == "disk_details":
|
||||||
raw = data.get("disk", {})
|
raw = data.get("disk", {})
|
||||||
@@ -159,22 +159,14 @@ def register_storage_tool(mcp: FastMCP) -> None:
|
|||||||
return {"summary": summary, "details": raw}
|
return {"summary": summary, "details": raw}
|
||||||
|
|
||||||
if action == "unassigned":
|
if action == "unassigned":
|
||||||
devices = data.get("unassignedDevices", [])
|
return {"devices": data.get("unassignedDevices", [])}
|
||||||
return {"devices": list(devices) if isinstance(devices, list) else []}
|
|
||||||
|
|
||||||
if action == "log_files":
|
if action == "log_files":
|
||||||
files = data.get("logFiles", [])
|
return {"log_files": data.get("logFiles", [])}
|
||||||
return {"log_files": list(files) if isinstance(files, list) else []}
|
|
||||||
|
|
||||||
if action == "logs":
|
if action == "logs":
|
||||||
return dict(data.get("logFile") or {})
|
return dict(data.get("logFile") or {})
|
||||||
|
|
||||||
raise ToolError(f"Unhandled action '{action}' — this is a bug")
|
raise ToolError(f"Unhandled action '{action}' — this is a bug")
|
||||||
|
|
||||||
except ToolError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in unraid_storage action={action}: {e}", exc_info=True)
|
|
||||||
raise ToolError(f"Failed to execute storage/{action}: {e!s}") from e
|
|
||||||
|
|
||||||
logger.info("Storage tool registered successfully")
|
logger.info("Storage tool registered successfully")
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from fastmcp import FastMCP
|
|||||||
|
|
||||||
from ..config.logging import logger
|
from ..config.logging import logger
|
||||||
from ..core.client import make_graphql_request
|
from ..core.client import make_graphql_request
|
||||||
from ..core.exceptions import ToolError
|
from ..core.exceptions import ToolError, tool_error_handler
|
||||||
|
|
||||||
|
|
||||||
QUERIES: dict[str, str] = {
|
QUERIES: dict[str, str] = {
|
||||||
@@ -39,17 +39,11 @@ def register_users_tool(mcp: FastMCP) -> None:
|
|||||||
Note: Unraid API does not support user management operations (list, add, delete).
|
Note: Unraid API does not support user management operations (list, add, delete).
|
||||||
"""
|
"""
|
||||||
if action not in ALL_ACTIONS:
|
if action not in ALL_ACTIONS:
|
||||||
raise ToolError(f"Invalid action '{action}'. Must be: me")
|
raise ToolError(f"Invalid action '{action}'. Must be one of: {sorted(ALL_ACTIONS)}")
|
||||||
|
|
||||||
try:
|
with tool_error_handler("users", action, logger):
|
||||||
logger.info("Executing unraid_users action=me")
|
logger.info("Executing unraid_users action=me")
|
||||||
data = await make_graphql_request(QUERIES["me"])
|
data = await make_graphql_request(QUERIES["me"])
|
||||||
return data.get("me") or {}
|
return data.get("me") or {}
|
||||||
|
|
||||||
except ToolError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in unraid_users action=me: {e}", exc_info=True)
|
|
||||||
raise ToolError(f"Failed to execute users/me: {e!s}") from e
|
|
||||||
|
|
||||||
logger.info("Users tool registered successfully")
|
logger.info("Users tool registered successfully")
|
||||||
|
|||||||
@@ -4,13 +4,13 @@ Provides the `unraid_vm` tool with 9 actions for VM lifecycle management
|
|||||||
including start, stop, pause, resume, force stop, reboot, and reset.
|
including start, stop, pause, resume, force stop, reboot, and reset.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, get_args
|
||||||
|
|
||||||
from fastmcp import FastMCP
|
from fastmcp import FastMCP
|
||||||
|
|
||||||
from ..config.logging import logger
|
from ..config.logging import logger
|
||||||
from ..core.client import make_graphql_request
|
from ..core.client import make_graphql_request
|
||||||
from ..core.exceptions import ToolError
|
from ..core.exceptions import ToolError, tool_error_handler
|
||||||
|
|
||||||
|
|
||||||
QUERIES: dict[str, str] = {
|
QUERIES: dict[str, str] = {
|
||||||
@@ -19,6 +19,13 @@ QUERIES: dict[str, str] = {
|
|||||||
vms { id domains { id name state uuid } }
|
vms { id domains { id name state uuid } }
|
||||||
}
|
}
|
||||||
""",
|
""",
|
||||||
|
# NOTE: The Unraid GraphQL API does not expose a single-VM query.
|
||||||
|
# The details query is identical to list; client-side filtering is required.
|
||||||
|
"details": """
|
||||||
|
query ListVMs {
|
||||||
|
vms { id domains { id name state uuid } }
|
||||||
|
}
|
||||||
|
""",
|
||||||
}
|
}
|
||||||
|
|
||||||
MUTATIONS: dict[str, str] = {
|
MUTATIONS: dict[str, str] = {
|
||||||
@@ -64,7 +71,15 @@ VM_ACTIONS = Literal[
|
|||||||
"reset",
|
"reset",
|
||||||
]
|
]
|
||||||
|
|
||||||
ALL_ACTIONS = set(QUERIES) | set(MUTATIONS) | {"details"}
|
ALL_ACTIONS = set(QUERIES) | set(MUTATIONS)
|
||||||
|
|
||||||
|
if set(get_args(VM_ACTIONS)) != ALL_ACTIONS:
|
||||||
|
_missing = ALL_ACTIONS - set(get_args(VM_ACTIONS))
|
||||||
|
_extra = set(get_args(VM_ACTIONS)) - ALL_ACTIONS
|
||||||
|
raise RuntimeError(
|
||||||
|
f"VM_ACTIONS and ALL_ACTIONS are out of sync. "
|
||||||
|
f"Missing from Literal: {_missing or 'none'}. Extra in Literal: {_extra or 'none'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def register_vm_tool(mcp: FastMCP) -> None:
|
def register_vm_tool(mcp: FastMCP) -> None:
|
||||||
@@ -98,20 +113,26 @@ def register_vm_tool(mcp: FastMCP) -> None:
|
|||||||
if action in DESTRUCTIVE_ACTIONS and not confirm:
|
if action in DESTRUCTIVE_ACTIONS and not confirm:
|
||||||
raise ToolError(f"Action '{action}' is destructive. Set confirm=True to proceed.")
|
raise ToolError(f"Action '{action}' is destructive. Set confirm=True to proceed.")
|
||||||
|
|
||||||
try:
|
with tool_error_handler("vm", action, logger):
|
||||||
logger.info(f"Executing unraid_vm action={action}")
|
try:
|
||||||
|
logger.info(f"Executing unraid_vm action={action}")
|
||||||
|
|
||||||
if action in ("list", "details"):
|
if action == "list":
|
||||||
data = await make_graphql_request(QUERIES["list"])
|
data = await make_graphql_request(QUERIES["list"])
|
||||||
if data.get("vms"):
|
if data.get("vms"):
|
||||||
|
vms = data["vms"].get("domains") or data["vms"].get("domain") or []
|
||||||
|
if isinstance(vms, dict):
|
||||||
|
vms = [vms]
|
||||||
|
return {"vms": vms}
|
||||||
|
return {"vms": []}
|
||||||
|
|
||||||
|
if action == "details":
|
||||||
|
data = await make_graphql_request(QUERIES["details"])
|
||||||
|
if not data.get("vms"):
|
||||||
|
raise ToolError("No VM data returned from server")
|
||||||
vms = data["vms"].get("domains") or data["vms"].get("domain") or []
|
vms = data["vms"].get("domains") or data["vms"].get("domain") or []
|
||||||
if isinstance(vms, dict):
|
if isinstance(vms, dict):
|
||||||
vms = [vms]
|
vms = [vms]
|
||||||
|
|
||||||
if action == "list":
|
|
||||||
return {"vms": vms}
|
|
||||||
|
|
||||||
# details: find specific VM
|
|
||||||
for vm in vms:
|
for vm in vms:
|
||||||
if (
|
if (
|
||||||
vm.get("uuid") == vm_id
|
vm.get("uuid") == vm_id
|
||||||
@@ -121,33 +142,28 @@ def register_vm_tool(mcp: FastMCP) -> None:
|
|||||||
return dict(vm)
|
return dict(vm)
|
||||||
available = [f"{v.get('name')} (UUID: {v.get('uuid')})" for v in vms]
|
available = [f"{v.get('name')} (UUID: {v.get('uuid')})" for v in vms]
|
||||||
raise ToolError(f"VM '{vm_id}' not found. Available: {', '.join(available)}")
|
raise ToolError(f"VM '{vm_id}' not found. Available: {', '.join(available)}")
|
||||||
if action == "details":
|
|
||||||
raise ToolError("No VM data returned from server")
|
|
||||||
return {"vms": []}
|
|
||||||
|
|
||||||
# Mutations
|
# Mutations
|
||||||
if action in MUTATIONS:
|
if action in MUTATIONS:
|
||||||
data = await make_graphql_request(MUTATIONS[action], {"id": vm_id})
|
data = await make_graphql_request(MUTATIONS[action], {"id": vm_id})
|
||||||
field = _MUTATION_FIELDS.get(action, action)
|
field = _MUTATION_FIELDS.get(action, action)
|
||||||
if data.get("vm") and field in data["vm"]:
|
if data.get("vm") and field in data["vm"]:
|
||||||
return {
|
return {
|
||||||
"success": data["vm"][field],
|
"success": data["vm"][field],
|
||||||
"action": action,
|
"action": action,
|
||||||
"vm_id": vm_id,
|
"vm_id": vm_id,
|
||||||
}
|
}
|
||||||
raise ToolError(f"Failed to {action} VM or unexpected response")
|
raise ToolError(f"Failed to {action} VM or unexpected response")
|
||||||
|
|
||||||
raise ToolError(f"Unhandled action '{action}' — this is a bug")
|
raise ToolError(f"Unhandled action '{action}' — this is a bug")
|
||||||
|
|
||||||
except ToolError:
|
except ToolError:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in unraid_vm action={action}: {e}", exc_info=True)
|
if "VMs are not available" in str(e):
|
||||||
msg = str(e)
|
raise ToolError(
|
||||||
if "VMs are not available" in msg:
|
"VMs not available on this server. Check VM support is enabled."
|
||||||
raise ToolError(
|
) from e
|
||||||
"VMs not available on this server. Check VM support is enabled."
|
raise
|
||||||
) from e
|
|
||||||
raise ToolError(f"Failed to execute vm/{action}: {msg}") from e
|
|
||||||
|
|
||||||
logger.info("VM tool registered successfully")
|
logger.info("VM tool registered successfully")
|
||||||
|
|||||||
11
unraid_mcp/version.py
Normal file
11
unraid_mcp/version.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
"""Application version helpers."""
|
||||||
|
|
||||||
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["VERSION"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
VERSION = version("unraid-mcp")
|
||||||
|
except PackageNotFoundError:
|
||||||
|
VERSION = "0.0.0"
|
||||||
13
uv.lock
generated
13
uv.lock
generated
@@ -1706,15 +1706,6 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/14/2c/dee705c427875402200fe779eb8a3c00ccb349471172c41178336e9599cc/typer-0.23.2-py3-none-any.whl", hash = "sha256:e9c8dc380f82450b3c851a9b9d5a0edf95d1d6456ae70c517d8b06a50c7a9978", size = 56834, upload-time = "2026-02-16T18:52:39.308Z" },
|
{ url = "https://files.pythonhosted.org/packages/14/2c/dee705c427875402200fe779eb8a3c00ccb349471172c41178336e9599cc/typer-0.23.2-py3-none-any.whl", hash = "sha256:e9c8dc380f82450b3c851a9b9d5a0edf95d1d6456ae70c517d8b06a50c7a9978", size = 56834, upload-time = "2026-02-16T18:52:39.308Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "types-pytz"
|
|
||||||
version = "2025.2.0.20251108"
|
|
||||||
source = { registry = "https://pypi.org/simple" }
|
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/40/ff/c047ddc68c803b46470a357454ef76f4acd8c1088f5cc4891cdd909bfcf6/types_pytz-2025.2.0.20251108.tar.gz", hash = "sha256:fca87917836ae843f07129567b74c1929f1870610681b4c92cb86a3df5817bdb", size = 10961, upload-time = "2025-11-08T02:55:57.001Z" }
|
|
||||||
wheels = [
|
|
||||||
{ url = "https://files.pythonhosted.org/packages/e7/c1/56ef16bf5dcd255155cc736d276efa6ae0a5c26fd685e28f0412a4013c01/types_pytz-2025.2.0.20251108-py3-none-any.whl", hash = "sha256:0f1c9792cab4eb0e46c52f8845c8f77cf1e313cb3d68bf826aa867fe4717d91c", size = 10116, upload-time = "2025-11-08T02:55:56.194Z" },
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "typing-extensions"
|
name = "typing-extensions"
|
||||||
version = "4.15.0"
|
version = "4.15.0"
|
||||||
@@ -1745,7 +1736,6 @@ dependencies = [
|
|||||||
{ name = "fastmcp" },
|
{ name = "fastmcp" },
|
||||||
{ name = "httpx" },
|
{ name = "httpx" },
|
||||||
{ name = "python-dotenv" },
|
{ name = "python-dotenv" },
|
||||||
{ name = "pytz" },
|
|
||||||
{ name = "rich" },
|
{ name = "rich" },
|
||||||
{ name = "uvicorn", extra = ["standard"] },
|
{ name = "uvicorn", extra = ["standard"] },
|
||||||
{ name = "websockets" },
|
{ name = "websockets" },
|
||||||
@@ -1762,7 +1752,6 @@ dev = [
|
|||||||
{ name = "ruff" },
|
{ name = "ruff" },
|
||||||
{ name = "twine" },
|
{ name = "twine" },
|
||||||
{ name = "ty" },
|
{ name = "ty" },
|
||||||
{ name = "types-pytz" },
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.metadata]
|
[package.metadata]
|
||||||
@@ -1771,7 +1760,6 @@ requires-dist = [
|
|||||||
{ name = "fastmcp", specifier = ">=2.14.5" },
|
{ name = "fastmcp", specifier = ">=2.14.5" },
|
||||||
{ name = "httpx", specifier = ">=0.28.1" },
|
{ name = "httpx", specifier = ">=0.28.1" },
|
||||||
{ name = "python-dotenv", specifier = ">=1.1.1" },
|
{ name = "python-dotenv", specifier = ">=1.1.1" },
|
||||||
{ name = "pytz", specifier = ">=2025.2" },
|
|
||||||
{ name = "rich", specifier = ">=14.1.0" },
|
{ name = "rich", specifier = ">=14.1.0" },
|
||||||
{ name = "uvicorn", extras = ["standard"], specifier = ">=0.35.0" },
|
{ name = "uvicorn", extras = ["standard"], specifier = ">=0.35.0" },
|
||||||
{ name = "websockets", specifier = ">=15.0.1" },
|
{ name = "websockets", specifier = ">=15.0.1" },
|
||||||
@@ -1788,7 +1776,6 @@ dev = [
|
|||||||
{ name = "ruff", specifier = ">=0.12.8" },
|
{ name = "ruff", specifier = ">=0.12.8" },
|
||||||
{ name = "twine", specifier = ">=6.0.1" },
|
{ name = "twine", specifier = ">=6.0.1" },
|
||||||
{ name = "ty", specifier = ">=0.0.15" },
|
{ name = "ty", specifier = ">=0.0.15" },
|
||||||
{ name = "types-pytz", specifier = ">=2025.2.0.20250809" },
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|||||||
Reference in New Issue
Block a user