mirror of
https://github.com/jmagar/unraid-mcp.git
synced 2026-03-23 04:29:17 -07:00
fix(subscriptions): bound snapshot loops with asyncio.timeout, raise on collect errors
- Wrap async-for loops in asyncio.timeout() so both subscribe_once and subscribe_collect
cannot hang indefinitely when no messages arrive after the handshake
- subscribe_once: TimeoutError → ToolError("Subscription timed out after Xs")
- subscribe_collect: TimeoutError → pass (return events collected so far)
- Remove manual deadline checks inside the loops (now redundant)
- subscribe_collect now raises ToolError on GraphQL payload errors instead of silently dropping them
- subscribe_collect handshake now distinguishes connection_error (auth) from unexpected type
This commit is contained in:
@@ -76,25 +76,25 @@ async def subscribe_once(
|
|||||||
|
|
||||||
# Await first matching data event
|
# Await first matching data event
|
||||||
expected_type = "next" if proto == "graphql-transport-ws" else "data"
|
expected_type = "next" if proto == "graphql-transport-ws" else "data"
|
||||||
deadline = asyncio.get_event_loop().time() + timeout
|
|
||||||
|
|
||||||
async for raw_msg in ws:
|
try:
|
||||||
if asyncio.get_event_loop().time() >= deadline:
|
async with asyncio.timeout(timeout):
|
||||||
raise ToolError(f"Subscription timed out after {timeout:.0f}s")
|
async for raw_msg in ws:
|
||||||
|
msg = json.loads(raw_msg)
|
||||||
msg = json.loads(raw_msg)
|
if msg.get("type") == "ping":
|
||||||
if msg.get("type") == "ping":
|
await ws.send(json.dumps({"type": "pong"}))
|
||||||
await ws.send(json.dumps({"type": "pong"}))
|
continue
|
||||||
continue
|
if msg.get("type") == expected_type and msg.get("id") == sub_id:
|
||||||
if msg.get("type") == expected_type and msg.get("id") == sub_id:
|
payload = msg.get("payload", {})
|
||||||
payload = msg.get("payload", {})
|
if errors := payload.get("errors"):
|
||||||
if errors := payload.get("errors"):
|
msgs = "; ".join(e.get("message", str(e)) for e in errors)
|
||||||
msgs = "; ".join(e.get("message", str(e)) for e in errors)
|
raise ToolError(f"Subscription errors: {msgs}")
|
||||||
raise ToolError(f"Subscription errors: {msgs}")
|
if data := payload.get("data"):
|
||||||
if data := payload.get("data"):
|
return data
|
||||||
return data
|
elif msg.get("type") == "error" and msg.get("id") == sub_id:
|
||||||
elif msg.get("type") == "error" and msg.get("id") == sub_id:
|
raise ToolError(f"Subscription error: {msg.get('payload')}")
|
||||||
raise ToolError(f"Subscription error: {msg.get('payload')}")
|
except TimeoutError:
|
||||||
|
raise ToolError(f"Subscription timed out after {timeout:.0f}s")
|
||||||
|
|
||||||
raise ToolError("WebSocket closed before receiving subscription data")
|
raise ToolError("WebSocket closed before receiving subscription data")
|
||||||
|
|
||||||
@@ -132,8 +132,10 @@ async def subscribe_collect(
|
|||||||
|
|
||||||
raw = await asyncio.wait_for(ws.recv(), timeout=timeout)
|
raw = await asyncio.wait_for(ws.recv(), timeout=timeout)
|
||||||
ack = json.loads(raw)
|
ack = json.loads(raw)
|
||||||
if ack.get("type") not in ("connection_ack",):
|
if ack.get("type") == "connection_error":
|
||||||
raise ToolError(f"Subscription handshake failed: {ack.get('type')}")
|
raise ToolError(f"Subscription auth failed: {ack.get('payload')}")
|
||||||
|
if ack.get("type") != "connection_ack":
|
||||||
|
raise ToolError(f"Unexpected handshake response: {ack.get('type')}")
|
||||||
|
|
||||||
start_type = "subscribe" if proto == "graphql-transport-ws" else "start"
|
start_type = "subscribe" if proto == "graphql-transport-ws" else "start"
|
||||||
await ws.send(
|
await ws.send(
|
||||||
@@ -147,19 +149,23 @@ async def subscribe_collect(
|
|||||||
)
|
)
|
||||||
|
|
||||||
expected_type = "next" if proto == "graphql-transport-ws" else "data"
|
expected_type = "next" if proto == "graphql-transport-ws" else "data"
|
||||||
collect_deadline = asyncio.get_event_loop().time() + collect_for
|
|
||||||
|
|
||||||
async for raw_msg in ws:
|
try:
|
||||||
if asyncio.get_event_loop().time() >= collect_deadline:
|
async with asyncio.timeout(collect_for):
|
||||||
break
|
async for raw_msg in ws:
|
||||||
msg = json.loads(raw_msg)
|
msg = json.loads(raw_msg)
|
||||||
if msg.get("type") == "ping":
|
if msg.get("type") == "ping":
|
||||||
await ws.send(json.dumps({"type": "pong"}))
|
await ws.send(json.dumps({"type": "pong"}))
|
||||||
continue
|
continue
|
||||||
if msg.get("type") == expected_type and msg.get("id") == sub_id:
|
if msg.get("type") == expected_type and msg.get("id") == sub_id:
|
||||||
payload = msg.get("payload", {})
|
payload = msg.get("payload", {})
|
||||||
if data := payload.get("data"):
|
if errors := payload.get("errors"):
|
||||||
events.append(data)
|
msgs = "; ".join(e.get("message", str(e)) for e in errors)
|
||||||
|
raise ToolError(f"Subscription errors: {msgs}")
|
||||||
|
if data := payload.get("data"):
|
||||||
|
events.append(data)
|
||||||
|
except TimeoutError:
|
||||||
|
pass # Collection window expired — return whatever was collected
|
||||||
|
|
||||||
logger.debug(f"[SNAPSHOT] Collected {len(events)} events in {collect_for}s window")
|
logger.debug(f"[SNAPSHOT] Collected {len(events)} events in {collect_for}s window")
|
||||||
return events
|
return events
|
||||||
|
|||||||
Reference in New Issue
Block a user