fix(subscriptions): bound snapshot loops with asyncio.timeout, raise on collect errors

- Wrap async-for loops in asyncio.timeout() so both subscribe_once and subscribe_collect
  cannot hang indefinitely when no messages arrive after the handshake
- subscribe_once: TimeoutError → ToolError("Subscription timed out after Xs")
- subscribe_collect: TimeoutError → pass (return events collected so far)
- Remove manual deadline checks inside the loops (now redundant)
- subscribe_collect now raises ToolError on GraphQL payload errors instead of silently dropping them
- subscribe_collect handshake now distinguishes connection_error (auth) from unexpected type
This commit is contained in:
Jacob Magar
2026-03-15 18:52:27 -04:00
parent 181ad53414
commit 5a3e8e285b

View File

@@ -76,25 +76,25 @@ async def subscribe_once(
# Await first matching data event # Await first matching data event
expected_type = "next" if proto == "graphql-transport-ws" else "data" expected_type = "next" if proto == "graphql-transport-ws" else "data"
deadline = asyncio.get_event_loop().time() + timeout
async for raw_msg in ws: try:
if asyncio.get_event_loop().time() >= deadline: async with asyncio.timeout(timeout):
raise ToolError(f"Subscription timed out after {timeout:.0f}s") async for raw_msg in ws:
msg = json.loads(raw_msg)
msg = json.loads(raw_msg) if msg.get("type") == "ping":
if msg.get("type") == "ping": await ws.send(json.dumps({"type": "pong"}))
await ws.send(json.dumps({"type": "pong"})) continue
continue if msg.get("type") == expected_type and msg.get("id") == sub_id:
if msg.get("type") == expected_type and msg.get("id") == sub_id: payload = msg.get("payload", {})
payload = msg.get("payload", {}) if errors := payload.get("errors"):
if errors := payload.get("errors"): msgs = "; ".join(e.get("message", str(e)) for e in errors)
msgs = "; ".join(e.get("message", str(e)) for e in errors) raise ToolError(f"Subscription errors: {msgs}")
raise ToolError(f"Subscription errors: {msgs}") if data := payload.get("data"):
if data := payload.get("data"): return data
return data elif msg.get("type") == "error" and msg.get("id") == sub_id:
elif msg.get("type") == "error" and msg.get("id") == sub_id: raise ToolError(f"Subscription error: {msg.get('payload')}")
raise ToolError(f"Subscription error: {msg.get('payload')}") except TimeoutError:
raise ToolError(f"Subscription timed out after {timeout:.0f}s")
raise ToolError("WebSocket closed before receiving subscription data") raise ToolError("WebSocket closed before receiving subscription data")
@@ -132,8 +132,10 @@ async def subscribe_collect(
raw = await asyncio.wait_for(ws.recv(), timeout=timeout) raw = await asyncio.wait_for(ws.recv(), timeout=timeout)
ack = json.loads(raw) ack = json.loads(raw)
if ack.get("type") not in ("connection_ack",): if ack.get("type") == "connection_error":
raise ToolError(f"Subscription handshake failed: {ack.get('type')}") raise ToolError(f"Subscription auth failed: {ack.get('payload')}")
if ack.get("type") != "connection_ack":
raise ToolError(f"Unexpected handshake response: {ack.get('type')}")
start_type = "subscribe" if proto == "graphql-transport-ws" else "start" start_type = "subscribe" if proto == "graphql-transport-ws" else "start"
await ws.send( await ws.send(
@@ -147,19 +149,23 @@ async def subscribe_collect(
) )
expected_type = "next" if proto == "graphql-transport-ws" else "data" expected_type = "next" if proto == "graphql-transport-ws" else "data"
collect_deadline = asyncio.get_event_loop().time() + collect_for
async for raw_msg in ws: try:
if asyncio.get_event_loop().time() >= collect_deadline: async with asyncio.timeout(collect_for):
break async for raw_msg in ws:
msg = json.loads(raw_msg) msg = json.loads(raw_msg)
if msg.get("type") == "ping": if msg.get("type") == "ping":
await ws.send(json.dumps({"type": "pong"})) await ws.send(json.dumps({"type": "pong"}))
continue continue
if msg.get("type") == expected_type and msg.get("id") == sub_id: if msg.get("type") == expected_type and msg.get("id") == sub_id:
payload = msg.get("payload", {}) payload = msg.get("payload", {})
if data := payload.get("data"): if errors := payload.get("errors"):
events.append(data) msgs = "; ".join(e.get("message", str(e)) for e in errors)
raise ToolError(f"Subscription errors: {msgs}")
if data := payload.get("data"):
events.append(data)
except TimeoutError:
pass # Collection window expired — return whatever was collected
logger.debug(f"[SNAPSHOT] Collected {len(events)} events in {collect_for}s window") logger.debug(f"[SNAPSHOT] Collected {len(events)} events in {collect_for}s window")
return events return events