feat: harden API safety and expand command docs with full test coverage

This commit is contained in:
Jacob Magar
2026-02-15 22:15:51 -05:00
parent d791c6b6b7
commit abb7915672
60 changed files with 7122 additions and 1247 deletions

View File

@@ -30,7 +30,9 @@ class SubscriptionManager:
self.subscription_lock = asyncio.Lock()
# Configuration
self.auto_start_enabled = os.getenv("UNRAID_AUTO_START_SUBSCRIPTIONS", "true").lower() == "true"
self.auto_start_enabled = (
os.getenv("UNRAID_AUTO_START_SUBSCRIPTIONS", "true").lower() == "true"
)
self.reconnect_attempts: dict[str, int] = {}
self.max_reconnect_attempts = int(os.getenv("UNRAID_MAX_RECONNECT_ATTEMPTS", "10"))
self.connection_states: dict[str, str] = {} # Track connection state per subscription
@@ -50,12 +52,16 @@ class SubscriptionManager:
""",
"resource": "unraid://logs/stream",
"description": "Real-time log file streaming",
"auto_start": False # Started manually with path parameter
"auto_start": False, # Started manually with path parameter
}
}
logger.info(f"[SUBSCRIPTION_MANAGER] Initialized with auto_start={self.auto_start_enabled}, max_reconnects={self.max_reconnect_attempts}")
logger.debug(f"[SUBSCRIPTION_MANAGER] Available subscriptions: {list(self.subscription_configs.keys())}")
logger.info(
f"[SUBSCRIPTION_MANAGER] Initialized with auto_start={self.auto_start_enabled}, max_reconnects={self.max_reconnect_attempts}"
)
logger.debug(
f"[SUBSCRIPTION_MANAGER] Available subscriptions: {list(self.subscription_configs.keys())}"
)
async def auto_start_all_subscriptions(self) -> None:
"""Auto-start all subscriptions marked for auto-start."""
@@ -69,21 +75,31 @@ class SubscriptionManager:
for subscription_name, config in self.subscription_configs.items():
if config.get("auto_start", False):
try:
logger.info(f"[SUBSCRIPTION_MANAGER] Auto-starting subscription: {subscription_name}")
logger.info(
f"[SUBSCRIPTION_MANAGER] Auto-starting subscription: {subscription_name}"
)
await self.start_subscription(subscription_name, str(config["query"]))
auto_start_count += 1
except Exception as e:
logger.error(f"[SUBSCRIPTION_MANAGER] Failed to auto-start {subscription_name}: {e}")
logger.error(
f"[SUBSCRIPTION_MANAGER] Failed to auto-start {subscription_name}: {e}"
)
self.last_error[subscription_name] = str(e)
logger.info(f"[SUBSCRIPTION_MANAGER] Auto-start completed. Started {auto_start_count} subscriptions")
logger.info(
f"[SUBSCRIPTION_MANAGER] Auto-start completed. Started {auto_start_count} subscriptions"
)
async def start_subscription(self, subscription_name: str, query: str, variables: dict[str, Any] | None = None) -> None:
async def start_subscription(
self, subscription_name: str, query: str, variables: dict[str, Any] | None = None
) -> None:
"""Start a GraphQL subscription and maintain it as a resource."""
logger.info(f"[SUBSCRIPTION:{subscription_name}] Starting subscription...")
if subscription_name in self.active_subscriptions:
logger.warning(f"[SUBSCRIPTION:{subscription_name}] Subscription already active, skipping")
logger.warning(
f"[SUBSCRIPTION:{subscription_name}] Subscription already active, skipping"
)
return
# Reset connection tracking
@@ -92,12 +108,18 @@ class SubscriptionManager:
async with self.subscription_lock:
try:
task = asyncio.create_task(self._subscription_loop(subscription_name, query, variables or {}))
task = asyncio.create_task(
self._subscription_loop(subscription_name, query, variables or {})
)
self.active_subscriptions[subscription_name] = task
logger.info(f"[SUBSCRIPTION:{subscription_name}] Subscription task created and started")
logger.info(
f"[SUBSCRIPTION:{subscription_name}] Subscription task created and started"
)
self.connection_states[subscription_name] = "active"
except Exception as e:
logger.error(f"[SUBSCRIPTION:{subscription_name}] Failed to start subscription task: {e}")
logger.error(
f"[SUBSCRIPTION:{subscription_name}] Failed to start subscription task: {e}"
)
self.connection_states[subscription_name] = "failed"
self.last_error[subscription_name] = str(e)
raise
@@ -120,7 +142,9 @@ class SubscriptionManager:
else:
logger.warning(f"[SUBSCRIPTION:{subscription_name}] No active subscription to stop")
async def _subscription_loop(self, subscription_name: str, query: str, variables: dict[str, Any] | None) -> None:
async def _subscription_loop(
self, subscription_name: str, query: str, variables: dict[str, Any] | None
) -> None:
"""Main loop for maintaining a GraphQL subscription with comprehensive logging."""
retry_delay: int | float = 5
max_retry_delay = 300 # 5 minutes max
@@ -129,10 +153,14 @@ class SubscriptionManager:
attempt = self.reconnect_attempts.get(subscription_name, 0) + 1
self.reconnect_attempts[subscription_name] = attempt
logger.info(f"[WEBSOCKET:{subscription_name}] Connection attempt #{attempt} (max: {self.max_reconnect_attempts})")
logger.info(
f"[WEBSOCKET:{subscription_name}] Connection attempt #{attempt} (max: {self.max_reconnect_attempts})"
)
if attempt > self.max_reconnect_attempts:
logger.error(f"[WEBSOCKET:{subscription_name}] Max reconnection attempts ({self.max_reconnect_attempts}) exceeded, stopping")
logger.error(
f"[WEBSOCKET:{subscription_name}] Max reconnection attempts ({self.max_reconnect_attempts}) exceeded, stopping"
)
self.connection_states[subscription_name] = "max_retries_exceeded"
break
@@ -142,9 +170,9 @@ class SubscriptionManager:
raise ValueError("UNRAID_API_URL is not configured")
if UNRAID_API_URL.startswith("https://"):
ws_url = "wss://" + UNRAID_API_URL[len("https://"):]
ws_url = "wss://" + UNRAID_API_URL[len("https://") :]
elif UNRAID_API_URL.startswith("http://"):
ws_url = "ws://" + UNRAID_API_URL[len("http://"):]
ws_url = "ws://" + UNRAID_API_URL[len("http://") :]
else:
ws_url = UNRAID_API_URL
@@ -152,13 +180,17 @@ class SubscriptionManager:
ws_url = ws_url.rstrip("/") + "/graphql"
logger.debug(f"[WEBSOCKET:{subscription_name}] Connecting to: {ws_url}")
logger.debug(f"[WEBSOCKET:{subscription_name}] API Key present: {'Yes' if UNRAID_API_KEY else 'No'}")
logger.debug(
f"[WEBSOCKET:{subscription_name}] API Key present: {'Yes' if UNRAID_API_KEY else 'No'}"
)
ssl_context = build_ws_ssl_context(ws_url)
# Connection with timeout
connect_timeout = 10
logger.debug(f"[WEBSOCKET:{subscription_name}] Connection timeout: {connect_timeout}s")
logger.debug(
f"[WEBSOCKET:{subscription_name}] Connection timeout: {connect_timeout}s"
)
async with websockets.connect(
ws_url,
@@ -166,11 +198,12 @@ class SubscriptionManager:
ping_interval=20,
ping_timeout=10,
close_timeout=10,
ssl=ssl_context
ssl=ssl_context,
) as websocket:
selected_proto = websocket.subprotocol or "none"
logger.info(f"[WEBSOCKET:{subscription_name}] Connected! Protocol: {selected_proto}")
logger.info(
f"[WEBSOCKET:{subscription_name}] Connected! Protocol: {selected_proto}"
)
self.connection_states[subscription_name] = "connected"
# Reset retry count on successful connection
@@ -178,21 +211,21 @@ class SubscriptionManager:
retry_delay = 5 # Reset delay
# Initialize GraphQL-WS protocol
logger.debug(f"[PROTOCOL:{subscription_name}] Initializing GraphQL-WS protocol...")
logger.debug(
f"[PROTOCOL:{subscription_name}] Initializing GraphQL-WS protocol..."
)
init_type = "connection_init"
init_payload: dict[str, Any] = {"type": init_type}
if UNRAID_API_KEY:
logger.debug(f"[AUTH:{subscription_name}] Adding authentication payload")
# Use standard X-API-Key header format (matching HTTP client)
auth_payload = {
"headers": {
"X-API-Key": UNRAID_API_KEY
}
}
auth_payload = {"headers": {"X-API-Key": UNRAID_API_KEY}}
init_payload["payload"] = auth_payload
else:
logger.warning(f"[AUTH:{subscription_name}] No API key available for authentication")
logger.warning(
f"[AUTH:{subscription_name}] No API key available for authentication"
)
logger.debug(f"[PROTOCOL:{subscription_name}] Sending connection_init message")
await websocket.send(json.dumps(init_payload))
@@ -203,45 +236,66 @@ class SubscriptionManager:
try:
init_data = json.loads(init_raw)
logger.debug(f"[PROTOCOL:{subscription_name}] Received init response: {init_data.get('type')}")
logger.debug(
f"[PROTOCOL:{subscription_name}] Received init response: {init_data.get('type')}"
)
except json.JSONDecodeError as e:
init_preview = init_raw[:200] if isinstance(init_raw, str) else init_raw[:200].decode("utf-8", errors="replace")
logger.error(f"[PROTOCOL:{subscription_name}] Failed to decode init response: {init_preview}...")
init_preview = (
init_raw[:200]
if isinstance(init_raw, str)
else init_raw[:200].decode("utf-8", errors="replace")
)
logger.error(
f"[PROTOCOL:{subscription_name}] Failed to decode init response: {init_preview}..."
)
self.last_error[subscription_name] = f"Invalid JSON in init response: {e}"
break
# Handle connection acknowledgment
if init_data.get("type") == "connection_ack":
logger.info(f"[PROTOCOL:{subscription_name}] Connection acknowledged successfully")
logger.info(
f"[PROTOCOL:{subscription_name}] Connection acknowledged successfully"
)
self.connection_states[subscription_name] = "authenticated"
elif init_data.get("type") == "connection_error":
error_payload = init_data.get("payload", {})
logger.error(f"[AUTH:{subscription_name}] Authentication failed: {error_payload}")
self.last_error[subscription_name] = f"Authentication error: {error_payload}"
logger.error(
f"[AUTH:{subscription_name}] Authentication failed: {error_payload}"
)
self.last_error[subscription_name] = (
f"Authentication error: {error_payload}"
)
self.connection_states[subscription_name] = "auth_failed"
break
else:
logger.warning(f"[PROTOCOL:{subscription_name}] Unexpected init response: {init_data}")
logger.warning(
f"[PROTOCOL:{subscription_name}] Unexpected init response: {init_data}"
)
# Continue anyway - some servers send other messages first
# Start the subscription
logger.debug(f"[SUBSCRIPTION:{subscription_name}] Starting GraphQL subscription...")
start_type = "subscribe" if selected_proto == "graphql-transport-ws" else "start"
logger.debug(
f"[SUBSCRIPTION:{subscription_name}] Starting GraphQL subscription..."
)
start_type = (
"subscribe" if selected_proto == "graphql-transport-ws" else "start"
)
subscription_message = {
"id": subscription_name,
"type": start_type,
"payload": {
"query": query,
"variables": variables
}
"payload": {"query": query, "variables": variables},
}
logger.debug(f"[SUBSCRIPTION:{subscription_name}] Subscription message type: {start_type}")
logger.debug(
f"[SUBSCRIPTION:{subscription_name}] Subscription message type: {start_type}"
)
logger.debug(f"[SUBSCRIPTION:{subscription_name}] Query: {query[:100]}...")
logger.debug(f"[SUBSCRIPTION:{subscription_name}] Variables: {variables}")
await websocket.send(json.dumps(subscription_message))
logger.info(f"[SUBSCRIPTION:{subscription_name}] Subscription started successfully")
logger.info(
f"[SUBSCRIPTION:{subscription_name}] Subscription started successfully"
)
self.connection_states[subscription_name] = "subscribed"
# Listen for subscription data
@@ -253,57 +307,100 @@ class SubscriptionManager:
message_count += 1
message_type = data.get("type", "unknown")
logger.debug(f"[DATA:{subscription_name}] Message #{message_count}: {message_type}")
logger.debug(
f"[DATA:{subscription_name}] Message #{message_count}: {message_type}"
)
# Handle different message types
expected_data_type = "next" if selected_proto == "graphql-transport-ws" else "data"
expected_data_type = (
"next" if selected_proto == "graphql-transport-ws" else "data"
)
if data.get("type") == expected_data_type and data.get("id") == subscription_name:
if (
data.get("type") == expected_data_type
and data.get("id") == subscription_name
):
payload = data.get("payload", {})
if payload.get("data"):
logger.info(f"[DATA:{subscription_name}] Received subscription data update")
logger.info(
f"[DATA:{subscription_name}] Received subscription data update"
)
self.resource_data[subscription_name] = SubscriptionData(
data=payload["data"],
last_updated=datetime.now(),
subscription_type=subscription_name
subscription_type=subscription_name,
)
logger.debug(
f"[RESOURCE:{subscription_name}] Resource data updated successfully"
)
logger.debug(f"[RESOURCE:{subscription_name}] Resource data updated successfully")
elif payload.get("errors"):
logger.error(f"[DATA:{subscription_name}] GraphQL errors in response: {payload['errors']}")
self.last_error[subscription_name] = f"GraphQL errors: {payload['errors']}"
logger.error(
f"[DATA:{subscription_name}] GraphQL errors in response: {payload['errors']}"
)
self.last_error[subscription_name] = (
f"GraphQL errors: {payload['errors']}"
)
else:
logger.warning(f"[DATA:{subscription_name}] Empty or invalid data payload: {payload}")
logger.warning(
f"[DATA:{subscription_name}] Empty or invalid data payload: {payload}"
)
elif data.get("type") == "ping":
logger.debug(f"[PROTOCOL:{subscription_name}] Received ping, sending pong")
logger.debug(
f"[PROTOCOL:{subscription_name}] Received ping, sending pong"
)
await websocket.send(json.dumps({"type": "pong"}))
elif data.get("type") == "error":
error_payload = data.get("payload", {})
logger.error(f"[SUBSCRIPTION:{subscription_name}] Subscription error: {error_payload}")
self.last_error[subscription_name] = f"Subscription error: {error_payload}"
logger.error(
f"[SUBSCRIPTION:{subscription_name}] Subscription error: {error_payload}"
)
self.last_error[subscription_name] = (
f"Subscription error: {error_payload}"
)
self.connection_states[subscription_name] = "error"
elif data.get("type") == "complete":
logger.info(f"[SUBSCRIPTION:{subscription_name}] Subscription completed by server")
logger.info(
f"[SUBSCRIPTION:{subscription_name}] Subscription completed by server"
)
self.connection_states[subscription_name] = "completed"
break
elif data.get("type") in ["ka", "ping", "pong"]:
logger.debug(f"[PROTOCOL:{subscription_name}] Keepalive message: {message_type}")
logger.debug(
f"[PROTOCOL:{subscription_name}] Keepalive message: {message_type}"
)
else:
logger.debug(f"[PROTOCOL:{subscription_name}] Unhandled message type: {message_type}")
logger.debug(
f"[PROTOCOL:{subscription_name}] Unhandled message type: {message_type}"
)
except json.JSONDecodeError as e:
msg_preview = message[:200] if isinstance(message, str) else message[:200].decode("utf-8", errors="replace")
logger.error(f"[PROTOCOL:{subscription_name}] Failed to decode message: {msg_preview}...")
msg_preview = (
message[:200]
if isinstance(message, str)
else message[:200].decode("utf-8", errors="replace")
)
logger.error(
f"[PROTOCOL:{subscription_name}] Failed to decode message: {msg_preview}..."
)
logger.error(f"[PROTOCOL:{subscription_name}] JSON decode error: {e}")
except Exception as e:
logger.error(f"[DATA:{subscription_name}] Error processing message: {e}")
msg_preview = message[:200] if isinstance(message, str) else message[:200].decode("utf-8", errors="replace")
logger.debug(f"[DATA:{subscription_name}] Raw message: {msg_preview}...")
logger.error(
f"[DATA:{subscription_name}] Error processing message: {e}"
)
msg_preview = (
message[:200]
if isinstance(message, str)
else message[:200].decode("utf-8", errors="replace")
)
logger.debug(
f"[DATA:{subscription_name}] Raw message: {msg_preview}..."
)
except TimeoutError:
error_msg = "Connection or authentication timeout"
@@ -332,7 +429,9 @@ class SubscriptionManager:
# Calculate backoff delay
retry_delay = min(retry_delay * 1.5, max_retry_delay)
logger.info(f"[WEBSOCKET:{subscription_name}] Reconnecting in {retry_delay:.1f} seconds...")
logger.info(
f"[WEBSOCKET:{subscription_name}] Reconnecting in {retry_delay:.1f} seconds..."
)
self.connection_states[subscription_name] = "reconnecting"
await asyncio.sleep(retry_delay)
@@ -363,14 +462,14 @@ class SubscriptionManager:
"config": {
"resource": config["resource"],
"description": config["description"],
"auto_start": config.get("auto_start", False)
"auto_start": config.get("auto_start", False),
},
"runtime": {
"active": sub_name in self.active_subscriptions,
"connection_state": self.connection_states.get(sub_name, "not_started"),
"reconnect_attempts": self.reconnect_attempts.get(sub_name, 0),
"last_error": self.last_error.get(sub_name, None)
}
"last_error": self.last_error.get(sub_name, None),
},
}
# Add data info if available
@@ -380,7 +479,7 @@ class SubscriptionManager:
sub_status["data"] = {
"available": True,
"last_updated": data_info.last_updated.isoformat(),
"age_seconds": age_seconds
"age_seconds": age_seconds,
}
else:
sub_status["data"] = {"available": False}