Skip to content

Commit

Permalink
feat: handle websocket disconnect gracefully
Browse files Browse the repository at this point in the history
  • Loading branch information
Noza23 committed Apr 16, 2024
1 parent 96d49b6 commit 3decf6a
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
9 changes: 6 additions & 3 deletions backend/routers/inference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Union, Any

from fastapi import APIRouter, UploadFile, WebSocket
from fastapi import APIRouter, UploadFile, WebSocket, WebSocketDisconnect
from fastapi import Depends, File
from fastapi import HTTPException, WebSocketException
from redis import asyncio as aioredis # type: ignore
Expand Down Expand Up @@ -114,8 +114,11 @@ async def inference_ws(
if len(myotubes) + len(nucleis) == 0:
await websocket.close()
break

data = await websocket.receive_json()
try:
data = await websocket.receive_json()
except WebSocketDisconnect:
print("Websocket disconnected.")
break
try:
point = Point.model_validate(data)
except ValidationError:
Expand Down
17 changes: 11 additions & 6 deletions backend/routers/validation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from fastapi import APIRouter, UploadFile, Depends, WebSocket
from fastapi import HTTPException, WebSocketException
from fastapi import HTTPException, WebSocketException, WebSocketDisconnect
from redis import asyncio as aioredis # type: ignore

from myo_sam.inference.pipeline import Pipeline
Expand Down Expand Up @@ -87,7 +87,6 @@ async def validation_ws(
"""Websocket for validation mode."""

await websocket.accept()
print("Hash Received: ", hash_str)

mo = Myotubes.model_validate_json(
await redis.get(KEYS.result_key(hash_str))
Expand All @@ -98,7 +97,7 @@ async def validation_ws(
i = state.get_next()

if state.done:
websocket.close()
await websocket.close()
return

await websocket.send_json(
Expand All @@ -110,7 +109,14 @@ async def validation_ws(
)

while True:
data = int(await websocket.receive_text())
if state.done:
break
try:
data = int(await websocket.receive_text())
except WebSocketDisconnect:
print("Websocket disconnected.")
break

if data == 0:
state.invalid.add(i)
elif data == 1:
Expand All @@ -128,15 +134,14 @@ async def validation_ws(
step = data != 2 if data != -1 else -1
i = min(max(i + step, 0), len(mo) - 1)

if i == len(mo):
if i + 1 == len(mo):
state.done = True
await redis.mset(
{
KEYS.state_key(hash_str): state.model_dump_json(),
KEYS.result_key(hash_str): mo.model_dump_json(),
}
)
break

# Send next contour
await websocket.send_json(
Expand Down

0 comments on commit 3decf6a

Please sign in to comment.