diff --git a/tornado/websocket.py b/tornado/websocket.py index d2c6a427aa..edc6993c0c 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -226,6 +226,8 @@ def __init__( self.close_code = None # type: Optional[int] self.close_reason = None # type: Optional[str] self.stream = None # type: Optional[IOStream] + self._opening = False + self._need_close = False self._on_close_called = False async def get(self, *args: Any, **kwargs: Any) -> None: @@ -563,7 +565,9 @@ def on_connection_close(self) -> None: if self.ws_connection: self.ws_connection.on_connection_close() self.ws_connection = None - if not self._on_close_called: + if self._opening: + self._need_close = True + elif not self._on_close_called: self._on_close_called = True self.on_close() self._break_cycles() @@ -950,9 +954,13 @@ async def _accept_connection(self, handler: WebSocketHandler) -> None: self.start_pinging() try: + handler._opening = True open_result = handler.open(*handler.open_args, **handler.open_kwargs) if open_result is not None: await open_result + handler._opening = False + if handler._need_close: + handler.on_connection_close() except Exception: handler.log_exception(*sys.exc_info()) self._abort()