diff --git a/unraid_mcp/subscriptions/snapshot.py b/unraid_mcp/subscriptions/snapshot.py index 839a9cb..811141c 100644 --- a/unraid_mcp/subscriptions/snapshot.py +++ b/unraid_mcp/subscriptions/snapshot.py @@ -76,25 +76,25 @@ async def subscribe_once( # Await first matching data event expected_type = "next" if proto == "graphql-transport-ws" else "data" - deadline = asyncio.get_event_loop().time() + timeout - async for raw_msg in ws: - if asyncio.get_event_loop().time() >= deadline: - raise ToolError(f"Subscription timed out after {timeout:.0f}s") - - msg = json.loads(raw_msg) - if msg.get("type") == "ping": - await ws.send(json.dumps({"type": "pong"})) - continue - if msg.get("type") == expected_type and msg.get("id") == sub_id: - payload = msg.get("payload", {}) - if errors := payload.get("errors"): - msgs = "; ".join(e.get("message", str(e)) for e in errors) - raise ToolError(f"Subscription errors: {msgs}") - if data := payload.get("data"): - return data - elif msg.get("type") == "error" and msg.get("id") == sub_id: - raise ToolError(f"Subscription error: {msg.get('payload')}") + try: + async with asyncio.timeout(timeout): + async for raw_msg in ws: + msg = json.loads(raw_msg) + if msg.get("type") == "ping": + await ws.send(json.dumps({"type": "pong"})) + continue + if msg.get("type") == expected_type and msg.get("id") == sub_id: + payload = msg.get("payload", {}) + if errors := payload.get("errors"): + msgs = "; ".join(e.get("message", str(e)) for e in errors) + raise ToolError(f"Subscription errors: {msgs}") + if data := payload.get("data"): + return data + elif msg.get("type") == "error" and msg.get("id") == sub_id: + raise ToolError(f"Subscription error: {msg.get('payload')}") + except TimeoutError: + raise ToolError(f"Subscription timed out after {timeout:.0f}s") raise ToolError("WebSocket closed before receiving subscription data") @@ -132,8 +132,10 @@ async def subscribe_collect( raw = await asyncio.wait_for(ws.recv(), timeout=timeout) ack = json.loads(raw) - if ack.get("type") not in ("connection_ack",): - raise ToolError(f"Subscription handshake failed: {ack.get('type')}") + if ack.get("type") == "connection_error": + raise ToolError(f"Subscription auth failed: {ack.get('payload')}") + if ack.get("type") != "connection_ack": + raise ToolError(f"Unexpected handshake response: {ack.get('type')}") start_type = "subscribe" if proto == "graphql-transport-ws" else "start" await ws.send( @@ -147,19 +149,23 @@ async def subscribe_collect( ) expected_type = "next" if proto == "graphql-transport-ws" else "data" - collect_deadline = asyncio.get_event_loop().time() + collect_for - async for raw_msg in ws: - if asyncio.get_event_loop().time() >= collect_deadline: - break - msg = json.loads(raw_msg) - if msg.get("type") == "ping": - await ws.send(json.dumps({"type": "pong"})) - continue - if msg.get("type") == expected_type and msg.get("id") == sub_id: - payload = msg.get("payload", {}) - if data := payload.get("data"): - events.append(data) + try: + async with asyncio.timeout(collect_for): + async for raw_msg in ws: + msg = json.loads(raw_msg) + if msg.get("type") == "ping": + await ws.send(json.dumps({"type": "pong"})) + continue + if msg.get("type") == expected_type and msg.get("id") == sub_id: + payload = msg.get("payload", {}) + if errors := payload.get("errors"): + msgs = "; ".join(e.get("message", str(e)) for e in errors) + raise ToolError(f"Subscription errors: {msgs}") + if data := payload.get("data"): + events.append(data) + except TimeoutError: + pass # Collection window expired — return whatever was collected logger.debug(f"[SNAPSHOT] Collected {len(events)} events in {collect_for}s window") return events