# Copyright 2025 Softwell S.r.l.
# Licensed under the Apache License, Version 2.0
"""WebSocket connection handling for ASGI applications.
Classes:
WebSocketState — IntEnum: CONNECTING(0), CONNECTED(1), DISCONNECTED(2)
WebSocket — wraps ASGI receive/send with Pythonic API
Flow: WebSocket(scope, receive, send) → accept() → send/receive → close()
Strict typing: receive_text() rejects bytes, receive_bytes() rejects text.
Optional TYTX support: receive_typed()/send_typed() for typed serialization.
"""
from __future__ import annotations
import json
from collections.abc import Mapping
from enum import IntEnum
from typing import TYPE_CHECKING, Any, AsyncIterator
from .datastructures import (
Address,
Headers,
QueryParams,
State,
URL,
headers_from_scope,
query_params_from_scope,
)
from .exceptions import WebSocketDisconnect
from .types import Message, Receive, Scope, Send
# Public API of this module.
__all__ = ["WebSocket", "WebSocketState"]
# Optional dependency: orjson for faster JSON encoding/decoding.
# When it is not installed, ``orjson`` is None and the stdlib ``json``
# module is used instead (see receive_json / send_json).
try:
    import orjson

    HAS_ORJSON = True
except ImportError:
    orjson = None  # type: ignore[assignment]
    HAS_ORJSON = False
# Optional dependency: genro-tytx for typed serialization, used by
# WebSocket.receive_typed() / send_typed(). When it is not installed,
# both helpers are None and HAS_TYTX is False.
try:
    from genro_tytx import from_tytx, to_tytx

    HAS_TYTX = True
except ImportError:
    HAS_TYTX = False
    from_tytx = None  # type: ignore[assignment]
    to_tytx = None  # type: ignore[assignment]

if TYPE_CHECKING:
    # Placeholder for imports needed only by static type checkers.
    pass
class WebSocketState(IntEnum):
    """
    WebSocket connection state.

    Represents the three possible states of a WebSocket connection:

    - CONNECTING: Initial state, before accept() is called
    - CONNECTED: After accept(), can send and receive messages
    - DISCONNECTED: After close() or client disconnect

    Example (inside a coroutine)::

        ws = WebSocket(scope, receive, send)
        assert ws.connection_state == WebSocketState.CONNECTING
        await ws.accept()
        assert ws.connection_state == WebSocketState.CONNECTED
    """

    CONNECTING = 0
    CONNECTED = 1
    DISCONNECTED = 2
class WebSocket:
    """
    WebSocket connection wrapper for ASGI.

    Provides a Pythonic interface for handling WebSocket connections through
    the ASGI protocol. Manages connection lifecycle (accept, close) and
    message sending/receiving.

    Attributes:
        scope: The ASGI scope dictionary.
        connection_state: Current WebSocketState.

    Example::

        async def handler(scope, receive, send):
            ws = WebSocket(scope, receive, send)
            await ws.accept()
            message = await ws.receive_text()
            await ws.send_text(f"Echo: {message}")
            await ws.close()
    """

    __slots__ = (
        "_scope",
        "_receive",
        "_send",
        "_connection_state",
        "_headers",
        "_query_params",
        "_url",
        "_state",
        "_accepted_subprotocol",
    )

    def __init__(
        self,
        scope: Scope,
        receive: Receive,
        send: Send,
    ) -> None:
        """
        Initialize WebSocket connection wrapper.

        Args:
            scope: ASGI WebSocket scope dictionary.
            receive: ASGI receive callable.
            send: ASGI send callable.

        Raises:
            ValueError: If scope type is not "websocket".
        """
        if scope.get("type") != "websocket":
            raise ValueError(f"Expected scope type 'websocket', got '{scope.get('type')}'")
        self._scope = scope
        self._receive = receive
        self._send = send
        self._connection_state = WebSocketState.CONNECTING
        # Views of the scope below are built lazily on first property access.
        self._headers: Headers | None = None
        self._query_params: QueryParams | None = None
        self._url: URL | None = None
        self._state: State | None = None
        self._accepted_subprotocol: str | None = None

    # =========================================================================
    # Properties
    # =========================================================================

    @property
    def scope(self) -> Scope:
        """The raw ASGI scope dictionary."""
        return self._scope

    @property
    def asgi_receive(self) -> Receive:
        """The ASGI receive callable."""
        return self._receive

    @property
    def asgi_send(self) -> Send:
        """The ASGI send callable."""
        return self._send

    @property
    def connection_state(self) -> WebSocketState:
        """Current connection state."""
        return self._connection_state

    @property
    def path(self) -> str:
        """URL path component."""
        return str(self._scope.get("path", "/"))

    @property
    def scheme(self) -> str:
        """URL scheme ('ws' or 'wss')."""
        return str(self._scope.get("scheme", "ws"))

    @property
    def url(self) -> URL:
        """
        Full WebSocket URL.

        Constructed lazily from scope fields: scheme, server, root_path,
        path, and query_string. When the scope has no TCP server address
        (``server`` missing, or ``[path, None]`` for a Unix socket), the
        Host header is used instead, defaulting to "localhost".

        Returns:
            URL object representing the complete WebSocket URL.
        """
        if self._url is None:
            scheme = self.scheme
            server = self._scope.get("server")
            path = self._scope.get("root_path", "") + self.path
            query_string = self._scope.get("query_string", b"")

            if server and server[1] is not None:
                host, port = server
                # Omit default ports
                if (scheme == "ws" and port == 80) or (scheme == "wss" and port == 443):
                    netloc = host
                else:
                    netloc = f"{host}:{port}"
            else:
                # No usable host/port pair (ASGI reports Unix sockets as
                # [path, None]); a socket path is not a URL host, so fall back
                # to the Host header.
                netloc = self.headers.get("host", "localhost")

            url_str = f"{scheme}://{netloc}{path}"
            if query_string:
                url_str += f"?{query_string.decode('latin-1')}"
            self._url = URL(url_str)
        return self._url

    @property
    def headers(self) -> Headers:
        """Connection headers (case-insensitive)."""
        if self._headers is None:
            self._headers = headers_from_scope(self._scope)
        return self._headers

    @property
    def query_params(self) -> QueryParams:
        """Query string parameters."""
        if self._query_params is None:
            self._query_params = query_params_from_scope(self._scope)
        return self._query_params

    @property
    def state(self) -> State:
        """Per-connection state container."""
        if self._state is None:
            self._state = State()
        return self._state

    @property
    def client(self) -> Address | None:
        """Client address (host, port) if available."""
        client = self._scope.get("client")
        if client:
            return Address(client[0], client[1])
        return None

    @property
    def subprotocols(self) -> tuple[str, ...]:
        """Subprotocols requested by the client (immutable)."""
        return tuple(self._scope.get("subprotocols", []))

    @property
    def accepted_subprotocol(self) -> str | None:
        """The subprotocol selected in accept(), or None."""
        return self._accepted_subprotocol

    # =========================================================================
    # Connection Lifecycle
    # =========================================================================

    async def accept(
        self,
        subprotocol: str | None = None,
        headers: Mapping[str, str] | None = None,
    ) -> None:
        """
        Accept the WebSocket connection.

        Must be called before sending or receiving messages. Consumes the
        ``websocket.connect`` message and sends ``websocket.accept``.

        Args:
            subprotocol: Optional subprotocol to negotiate.
            headers: Optional headers to include in accept response
                (latin-1 encoded).

        Raises:
            RuntimeError: If not in CONNECTING state or unexpected message.
                If the unexpected message is ``websocket.disconnect``, the
                connection state is set to DISCONNECTED before raising.
        """
        if self._connection_state != WebSocketState.CONNECTING:
            raise RuntimeError(f"Cannot accept: connection in {self._connection_state.name} state")

        # Consume the websocket.connect message
        message = await self._receive()
        if message["type"] != "websocket.connect":
            if message["type"] == "websocket.disconnect":
                # Client left during the handshake: reflect that in the state
                # so connection_state is accurate and close() is a no-op.
                self._connection_state = WebSocketState.DISCONNECTED
            raise RuntimeError(f"Expected websocket.connect, got {message['type']}")

        # Build accept message
        accept_message: Message = {"type": "websocket.accept"}
        if subprotocol is not None:
            accept_message["subprotocol"] = subprotocol
            self._accepted_subprotocol = subprotocol
        if headers is not None:
            accept_message["headers"] = [
                (k.encode("latin-1"), v.encode("latin-1")) for k, v in headers.items()
            ]

        await self._send(accept_message)
        self._connection_state = WebSocketState.CONNECTED

    async def close(
        self,
        code: int = 1000,
        reason: str = "",
    ) -> None:
        """
        Close the WebSocket connection.

        Idempotent: safe to call multiple times.

        Args:
            code: WebSocket close code (default 1000).
            reason: Optional close reason.

        Raises:
            RuntimeError: If connection not accepted yet.
        """
        if self._connection_state == WebSocketState.DISCONNECTED:
            return  # Idempotent
        if self._connection_state == WebSocketState.CONNECTING:
            # NOTE(review): the ASGI spec lets an app send websocket.close
            # before accepting to reject a handshake; this wrapper deliberately
            # does not support that and raises instead.
            raise RuntimeError("Cannot close: connection not accepted yet")

        await self._send(
            {
                "type": "websocket.close",
                "code": code,
                "reason": reason,
            }
        )
        self._connection_state = WebSocketState.DISCONNECTED

    # =========================================================================
    # Receiving Messages
    # =========================================================================

    async def _receive_message(self) -> Message:
        """
        Internal: receive message with state validation.

        Returns:
            The ASGI message dict.

        Raises:
            RuntimeError: If not connected.
            WebSocketDisconnect: If client disconnected (state becomes
                DISCONNECTED first).
        """
        if self._connection_state != WebSocketState.CONNECTED:
            raise RuntimeError(f"Cannot receive: connection in {self._connection_state.name} state")

        message = await self._receive()
        if message["type"] == "websocket.disconnect":
            self._connection_state = WebSocketState.DISCONNECTED
            raise WebSocketDisconnect(
                code=message.get("code", 1000),
                reason=message.get("reason", ""),
            )
        return message

    async def receive_text(self) -> str:
        """
        Receive a text message.

        Returns:
            The text message content ("" if the message carries no text).

        Raises:
            RuntimeError: If not connected.
            TypeError: If received binary data.
            WebSocketDisconnect: If client disconnected.
        """
        message = await self._receive_message()
        if message.get("bytes") is not None:
            raise TypeError("Received binary message. Use receive_bytes() for binary data.")
        # ASGI allows the key to be present with a None value; never return None.
        text: str = message.get("text") or ""
        return text

    async def receive_bytes(self) -> bytes:
        """
        Receive a binary message.

        Returns:
            The binary message content (b"" if the message carries no bytes).

        Raises:
            RuntimeError: If not connected.
            TypeError: If received text data.
            WebSocketDisconnect: If client disconnected.
        """
        message = await self._receive_message()
        if message.get("text") is not None:
            raise TypeError("Received text message. Use receive_text() for text data.")
        # ASGI allows the key to be present with a None value; never return None.
        data: bytes = message.get("bytes") or b""
        return data

    async def receive_json(self) -> Any:
        """
        Receive and parse a JSON text message.

        Uses orjson when available, otherwise the stdlib json module.

        Returns:
            The parsed JSON value.

        Raises:
            RuntimeError: If not connected.
            TypeError: If received binary data.
            json.JSONDecodeError: If not valid JSON.
            WebSocketDisconnect: If client disconnected.
        """
        text = await self.receive_text()
        if HAS_ORJSON and orjson is not None:
            return orjson.loads(text)
        return json.loads(text)

    # =========================================================================
    # Sending Messages
    # =========================================================================

    async def _send_message(self, message: Message) -> None:
        """
        Internal: send message with state validation.

        Args:
            message: ASGI message dict.

        Raises:
            RuntimeError: If not connected.
        """
        if self._connection_state != WebSocketState.CONNECTED:
            raise RuntimeError(f"Cannot send: connection in {self._connection_state.name} state")
        await self._send(message)

    async def send_text(self, data: str) -> None:
        """
        Send a text message.

        Args:
            data: Text to send.

        Raises:
            RuntimeError: If not connected.
        """
        await self._send_message({"type": "websocket.send", "text": data})

    async def send_bytes(self, data: bytes) -> None:
        """
        Send a binary message.

        Args:
            data: Bytes to send.

        Raises:
            RuntimeError: If not connected.
        """
        await self._send_message({"type": "websocket.send", "bytes": data})

    async def send_json(self, data: Any) -> None:
        """
        Serialize data to JSON and send it as a text message.

        Uses orjson when available, otherwise the stdlib json module.

        Args:
            data: JSON-serializable data.

        Raises:
            RuntimeError: If not connected.
            TypeError: If data not serializable.
        """
        if HAS_ORJSON and orjson is not None:
            text = orjson.dumps(data).decode("utf-8")
        else:
            text = json.dumps(data)
        await self.send_text(text)

    # =========================================================================
    # Iteration
    # =========================================================================

    async def iter_text(self) -> AsyncIterator[str]:
        """
        Async iterator yielding text messages.

        Yields:
            Each text message until disconnect (the disconnect ends the
            iteration instead of propagating WebSocketDisconnect).

        Raises:
            TypeError: If binary message received.
        """
        try:
            while True:
                yield await self.receive_text()
        except WebSocketDisconnect:
            return

    async def iter_bytes(self) -> AsyncIterator[bytes]:
        """
        Async iterator yielding binary messages.

        Yields:
            Each binary message until disconnect (the disconnect ends the
            iteration instead of propagating WebSocketDisconnect).

        Raises:
            TypeError: If text message received.
        """
        try:
            while True:
                yield await self.receive_bytes()
        except WebSocketDisconnect:
            return

    def __aiter__(self) -> AsyncIterator[str]:
        """
        Support async iteration (alias for iter_text).

        Returns:
            Async iterator over text messages.
        """
        return self.iter_text()

    # =========================================================================
    # Typed Messages (requires genro-tytx)
    # =========================================================================

    async def receive_typed(self) -> dict[str, Any]:
        """
        Receive JSON with optional TYTX hydration.

        Returns:
            Parsed dict, hydrated if TYTX marker present.

        Raises:
            ImportError: If genro-tytx not installed.
            RuntimeError: If not connected.
        """
        if not HAS_TYTX:
            raise ImportError(
                "genro-tytx package required for receive_typed(). "
                "Install with: pip install genro-tytx"
            )
        text = await self.receive_text()
        result: dict[str, Any] = from_tytx(text)  # Handles marker detection automatically
        return result

    async def send_typed(self, data: dict[str, Any]) -> None:
        """
        Serialize with TYTX and send with marker.

        Args:
            data: Dict with potentially typed values.

        Raises:
            ImportError: If genro-tytx not installed.
            RuntimeError: If not connected.
        """
        if not HAS_TYTX:
            raise ImportError(
                "genro-tytx package required for send_typed(). "
                "Install with: pip install genro-tytx"
            )
        encoded = to_tytx(data)  # Includes marker automatically
        # to_tytx may return str or bytes; the wire format here is text.
        text = encoded if isinstance(encoded, str) else encoded.decode("utf-8")
        await self.send_text(text)
if __name__ == "__main__":
    # Minimal demo with a mock ASGI interface.
    import asyncio

    async def demo() -> None:
        """Demo the WebSocket class with mock receive/send callables."""
        # Mock scope
        scope: Scope = {
            "type": "websocket",
            "scheme": "wss",
            "path": "/ws/chat",
            "query_string": b"room=general",
            "headers": [
                (b"host", b"example.com"),
                (b"sec-websocket-protocol", b"chat, superchat"),
            ],
            "server": ("example.com", 443),
            "client": ("192.168.1.100", 54321),
            "subprotocols": ["chat", "superchat"],
        }

        # Mock message queue, delivered in order by mock_receive().
        messages: list[Message] = [
            {"type": "websocket.connect"},
            {"type": "websocket.receive", "text": "Hello, server!"},
            {"type": "websocket.receive", "text": '{"action": "ping"}'},
            {"type": "websocket.disconnect", "code": 1000},
        ]
        msg_index = 0
        sent_messages: list[Message] = []

        async def mock_receive() -> Message:
            nonlocal msg_index
            if msg_index < len(messages):
                msg = messages[msg_index]
                msg_index += 1
                return msg
            # Queue exhausted: behave like a disconnected client.
            return {"type": "websocket.disconnect", "code": 1000}

        async def mock_send(message: Message) -> None:
            sent_messages.append(message)

        # Create WebSocket instance
        ws = WebSocket(scope, mock_receive, mock_send)

        print(f"Path: {ws.path}")
        print(f"Scheme: {ws.scheme}")
        print(f"URL: {ws.url}")
        print(f"Client: {ws.client}")
        print(f"Subprotocols: {ws.subprotocols}")
        print(f"State: {ws.connection_state.name}")
        print()

        # Accept connection
        await ws.accept(subprotocol="chat")
        print(f"Accepted with subprotocol: {ws.accepted_subprotocol}")
        print(f"State after accept: {ws.connection_state.name}")
        print()

        # Receive messages
        try:
            message = await ws.receive_text()
            print(f"Received text: {message}")
            data = await ws.receive_json()
            print(f"Received JSON: {data}")
            # This will trigger disconnect
            await ws.receive_text()
        except WebSocketDisconnect as e:
            print(f"Disconnected: code={e.code}")
            print(f"State after disconnect: {ws.connection_state.name}")
        print()

        print(f"Sent messages: {len(sent_messages)}")
        for msg in sent_messages:
            print(f"  {msg}")

    asyncio.run(demo())