mirror of
https://github.com/jmagar/unraid-mcp.git
synced 2026-03-23 04:29:17 -07:00
fix(subscriptions): bound snapshot loops with asyncio.timeout, raise on collect errors
- Wrap async-for loops in asyncio.timeout() so both subscribe_once and subscribe_collect
cannot hang indefinitely when no messages arrive after the handshake
- subscribe_once: TimeoutError → ToolError("Subscription timed out after Xs")
- subscribe_collect: TimeoutError → pass (return events collected so far)
- Remove manual deadline checks inside the loops (now redundant)
- subscribe_collect now raises ToolError on GraphQL payload errors instead of silently dropping them
- subscribe_collect handshake now distinguishes connection_error (auth) from unexpected type
This commit is contained in:
@@ -76,25 +76,25 @@ async def subscribe_once(
|
||||
|
||||
# Await first matching data event
|
||||
expected_type = "next" if proto == "graphql-transport-ws" else "data"
|
||||
deadline = asyncio.get_event_loop().time() + timeout
|
||||
|
||||
async for raw_msg in ws:
|
||||
if asyncio.get_event_loop().time() >= deadline:
|
||||
raise ToolError(f"Subscription timed out after {timeout:.0f}s")
|
||||
|
||||
msg = json.loads(raw_msg)
|
||||
if msg.get("type") == "ping":
|
||||
await ws.send(json.dumps({"type": "pong"}))
|
||||
continue
|
||||
if msg.get("type") == expected_type and msg.get("id") == sub_id:
|
||||
payload = msg.get("payload", {})
|
||||
if errors := payload.get("errors"):
|
||||
msgs = "; ".join(e.get("message", str(e)) for e in errors)
|
||||
raise ToolError(f"Subscription errors: {msgs}")
|
||||
if data := payload.get("data"):
|
||||
return data
|
||||
elif msg.get("type") == "error" and msg.get("id") == sub_id:
|
||||
raise ToolError(f"Subscription error: {msg.get('payload')}")
|
||||
try:
|
||||
async with asyncio.timeout(timeout):
|
||||
async for raw_msg in ws:
|
||||
msg = json.loads(raw_msg)
|
||||
if msg.get("type") == "ping":
|
||||
await ws.send(json.dumps({"type": "pong"}))
|
||||
continue
|
||||
if msg.get("type") == expected_type and msg.get("id") == sub_id:
|
||||
payload = msg.get("payload", {})
|
||||
if errors := payload.get("errors"):
|
||||
msgs = "; ".join(e.get("message", str(e)) for e in errors)
|
||||
raise ToolError(f"Subscription errors: {msgs}")
|
||||
if data := payload.get("data"):
|
||||
return data
|
||||
elif msg.get("type") == "error" and msg.get("id") == sub_id:
|
||||
raise ToolError(f"Subscription error: {msg.get('payload')}")
|
||||
except TimeoutError:
|
||||
raise ToolError(f"Subscription timed out after {timeout:.0f}s")
|
||||
|
||||
raise ToolError("WebSocket closed before receiving subscription data")
|
||||
|
||||
@@ -132,8 +132,10 @@ async def subscribe_collect(
|
||||
|
||||
raw = await asyncio.wait_for(ws.recv(), timeout=timeout)
|
||||
ack = json.loads(raw)
|
||||
if ack.get("type") not in ("connection_ack",):
|
||||
raise ToolError(f"Subscription handshake failed: {ack.get('type')}")
|
||||
if ack.get("type") == "connection_error":
|
||||
raise ToolError(f"Subscription auth failed: {ack.get('payload')}")
|
||||
if ack.get("type") != "connection_ack":
|
||||
raise ToolError(f"Unexpected handshake response: {ack.get('type')}")
|
||||
|
||||
start_type = "subscribe" if proto == "graphql-transport-ws" else "start"
|
||||
await ws.send(
|
||||
@@ -147,19 +149,23 @@ async def subscribe_collect(
|
||||
)
|
||||
|
||||
expected_type = "next" if proto == "graphql-transport-ws" else "data"
|
||||
collect_deadline = asyncio.get_event_loop().time() + collect_for
|
||||
|
||||
async for raw_msg in ws:
|
||||
if asyncio.get_event_loop().time() >= collect_deadline:
|
||||
break
|
||||
msg = json.loads(raw_msg)
|
||||
if msg.get("type") == "ping":
|
||||
await ws.send(json.dumps({"type": "pong"}))
|
||||
continue
|
||||
if msg.get("type") == expected_type and msg.get("id") == sub_id:
|
||||
payload = msg.get("payload", {})
|
||||
if data := payload.get("data"):
|
||||
events.append(data)
|
||||
try:
|
||||
async with asyncio.timeout(collect_for):
|
||||
async for raw_msg in ws:
|
||||
msg = json.loads(raw_msg)
|
||||
if msg.get("type") == "ping":
|
||||
await ws.send(json.dumps({"type": "pong"}))
|
||||
continue
|
||||
if msg.get("type") == expected_type and msg.get("id") == sub_id:
|
||||
payload = msg.get("payload", {})
|
||||
if errors := payload.get("errors"):
|
||||
msgs = "; ".join(e.get("message", str(e)) for e in errors)
|
||||
raise ToolError(f"Subscription errors: {msgs}")
|
||||
if data := payload.get("data"):
|
||||
events.append(data)
|
||||
except TimeoutError:
|
||||
pass # Collection window expired — return whatever was collected
|
||||
|
||||
logger.debug(f"[SNAPSHOT] Collected {len(events)} events in {collect_for}s window")
|
||||
return events
|
||||
|
||||
Reference in New Issue
Block a user