Source code for genro_asgi.websocket

# Copyright 2025 Softwell S.r.l.
# Licensed under the Apache License, Version 2.0

"""WebSocket connection handling for ASGI applications.

Classes:
    WebSocketState — IntEnum: CONNECTING(0), CONNECTED(1), DISCONNECTED(2)
    WebSocket — wraps ASGI receive/send with Pythonic API

Flow: WebSocket(scope, receive, send) → accept() → send/receive → close()

Strict typing: receive_text() rejects bytes, receive_bytes() rejects text.
Optional TYTX support: receive_typed()/send_typed() for typed serialization.
"""

from __future__ import annotations

import json
from collections.abc import Mapping
from enum import IntEnum
from typing import TYPE_CHECKING, Any, AsyncIterator

from .datastructures import (
    Address,
    Headers,
    QueryParams,
    State,
    URL,
    headers_from_scope,
    query_params_from_scope,
)
from .exceptions import WebSocketDisconnect
from .types import Message, Receive, Scope, Send

__all__ = [
    "WebSocket",
    "WebSocketState",
]

# Optional dependency: orjson for faster JSON
try:
    import orjson

    HAS_ORJSON = True
except ImportError:
    orjson = None  # type: ignore[assignment]
    HAS_ORJSON = False

# Optional dependency: genro-tytx for typed serialization
try:
    from genro_tytx import from_tytx, to_tytx

    HAS_TYTX = True
except ImportError:
    HAS_TYTX = False
    from_tytx = None  # type: ignore[assignment]
    to_tytx = None  # type: ignore[assignment]


if TYPE_CHECKING:
    pass


class WebSocketState(IntEnum):
    """
    Lifecycle state of a WebSocket connection.

    A connection moves through these states in order:

    - CONNECTING: starting state, before accept() has been called
    - CONNECTED: accept() has completed; messages may be sent and received
    - DISCONNECTED: close() was called or the client went away

    Example:
        >>> ws = WebSocket(scope, receive, send)
        >>> ws.connection_state == WebSocketState.CONNECTING
        True
        >>> await ws.accept()
        >>> ws.connection_state == WebSocketState.CONNECTED
        True
    """

    # Values 0, 1, 2 in declaration order, i.e. CONNECTING(0), CONNECTED(1),
    # DISCONNECTED(2); being an IntEnum they also compare equal to plain ints.
    CONNECTING, CONNECTED, DISCONNECTED = range(3)
class WebSocket:
    """
    WebSocket connection wrapper for ASGI.

    Provides a Pythonic interface for handling WebSocket connections
    through the ASGI protocol. Manages connection lifecycle (accept, close)
    and message sending/receiving.

    Attributes:
        scope: The ASGI scope dictionary.
        connection_state: Current WebSocketState.

    Example:
        >>> async def handler(scope, receive, send):
        ...     ws = WebSocket(scope, receive, send)
        ...     await ws.accept()
        ...     message = await ws.receive_text()
        ...     await ws.send_text(f"Echo: {message}")
        ...     await ws.close()
    """

    __slots__ = (
        "_scope",
        "_receive",
        "_send",
        "_connection_state",
        "_headers",
        "_query_params",
        "_url",
        "_state",
        "_accepted_subprotocol",
    )

    def __init__(
        self,
        scope: Scope,
        receive: Receive,
        send: Send,
    ) -> None:
        """
        Initialize WebSocket connection wrapper.

        Args:
            scope: ASGI WebSocket scope dictionary.
            receive: ASGI receive callable.
            send: ASGI send callable.

        Raises:
            ValueError: If scope type is not "websocket".
        """
        if scope.get("type") != "websocket":
            raise ValueError(f"Expected scope type 'websocket', got '{scope.get('type')}'")
        self._scope = scope
        self._receive = receive
        self._send = send
        self._connection_state = WebSocketState.CONNECTING
        # Lazily-built caches, populated on first access of the matching property.
        self._headers: Headers | None = None
        self._query_params: QueryParams | None = None
        self._url: URL | None = None
        self._state: State | None = None
        self._accepted_subprotocol: str | None = None

    # =========================================================================
    # Properties
    # =========================================================================

    @property
    def scope(self) -> Scope:
        """The raw ASGI scope dictionary."""
        return self._scope

    @property
    def asgi_receive(self) -> Receive:
        """The ASGI receive callable."""
        return self._receive

    @property
    def asgi_send(self) -> Send:
        """The ASGI send callable."""
        return self._send

    @property
    def connection_state(self) -> WebSocketState:
        """Current connection state."""
        return self._connection_state

    @property
    def path(self) -> str:
        """URL path component."""
        return str(self._scope.get("path", "/"))

    @property
    def scheme(self) -> str:
        """URL scheme ('ws' or 'wss')."""
        return str(self._scope.get("scheme", "ws"))

    @property
    def url(self) -> URL:
        """
        Full WebSocket URL.

        Constructed lazily (and cached) from scope fields: scheme, server,
        root_path, path, and query_string.

        The network location comes from ``scope["server"]`` when it holds a
        TCP ``(host, port)`` pair: default ports (80 for ``ws``, 443 for
        ``wss``) are omitted and IPv6 literals are wrapped in brackets. When
        ``server`` is missing, or is the ASGI Unix-socket form
        ``[path, None]``, the ``Host`` header is used instead, falling back
        to ``"localhost"``.

        Returns:
            URL object representing the complete WebSocket URL.
        """
        if self._url is None:
            scheme = self.scheme
            path = self._scope.get("root_path", "") + self.path
            query_string = self._scope.get("query_string", b"")

            netloc: str | None = None
            server = self._scope.get("server")
            if server:
                host, port = server
                # A None port means a Unix socket: the "host" is a filesystem
                # path, so it cannot be used as a URL authority.
                if port is not None:
                    host = str(host)
                    if ":" in host and not host.startswith("["):
                        # IPv6 literals must be bracketed inside a URL authority.
                        host = f"[{host}]"
                    # Omit default ports
                    if (scheme == "ws" and port == 80) or (scheme == "wss" and port == 443):
                        netloc = host
                    else:
                        netloc = f"{host}:{port}"
            if netloc is None:
                netloc = self.headers.get("host", "localhost")

            url_str = f"{scheme}://{netloc}{path}"
            if query_string:
                url_str += f"?{query_string.decode('latin-1')}"
            self._url = URL(url_str)
        return self._url

    @property
    def headers(self) -> Headers:
        """Connection headers (case-insensitive)."""
        if self._headers is None:
            self._headers = headers_from_scope(self._scope)
        return self._headers

    @property
    def query_params(self) -> QueryParams:
        """Query string parameters."""
        if self._query_params is None:
            self._query_params = query_params_from_scope(self._scope)
        return self._query_params

    @property
    def state(self) -> State:
        """Per-connection state container."""
        if self._state is None:
            self._state = State()
        return self._state

    @property
    def client(self) -> Address | None:
        """Client address (host, port) if available."""
        client = self._scope.get("client")
        if client:
            return Address(client[0], client[1])
        return None

    @property
    def subprotocols(self) -> tuple[str, ...]:
        """Subprotocols requested by the client (immutable)."""
        return tuple(self._scope.get("subprotocols", []))

    @property
    def accepted_subprotocol(self) -> str | None:
        """The subprotocol selected in accept(), or None."""
        return self._accepted_subprotocol

    # =========================================================================
    # Connection Lifecycle
    # =========================================================================

    async def accept(
        self,
        subprotocol: str | None = None,
        headers: Mapping[str, str] | None = None,
    ) -> None:
        """
        Accept the WebSocket connection.

        Must be called before sending or receiving messages. Consumes the
        ``websocket.connect`` message and sends ``websocket.accept``.

        Args:
            subprotocol: Optional subprotocol to negotiate.
            headers: Optional headers to include in accept response. Header
                names are lowercased (as ASGI requires); names and values are
                encoded as latin-1.

        Raises:
            RuntimeError: If not in CONNECTING state or unexpected message.
        """
        if self._connection_state != WebSocketState.CONNECTING:
            raise RuntimeError(f"Cannot accept: connection in {self._connection_state.name} state")

        # Consume the websocket.connect message
        message = await self._receive()
        if message["type"] != "websocket.connect":
            raise RuntimeError(f"Expected websocket.connect, got {message['type']}")

        # Build accept message
        accept_message: Message = {"type": "websocket.accept"}
        if subprotocol is not None:
            accept_message["subprotocol"] = subprotocol
        if headers is not None:
            # ASGI requires header names in websocket.accept to be lowercased.
            accept_message["headers"] = [
                (name.lower().encode("latin-1"), value.encode("latin-1"))
                for name, value in headers.items()
            ]

        await self._send(accept_message)
        # Record the negotiated subprotocol only once the accept was sent.
        self._accepted_subprotocol = subprotocol
        self._connection_state = WebSocketState.CONNECTED

    async def close(
        self,
        code: int = 1000,
        reason: str = "",
    ) -> None:
        """
        Close the WebSocket connection.

        Idempotent: safe to call multiple times, and a no-op after the client
        has disconnected.

        Args:
            code: WebSocket close code (default 1000).
            reason: Optional close reason.

        Raises:
            RuntimeError: If connection not accepted yet.
        """
        if self._connection_state == WebSocketState.DISCONNECTED:
            return  # Idempotent
        if self._connection_state == WebSocketState.CONNECTING:
            raise RuntimeError("Cannot close: connection not accepted yet")

        await self._send(
            {
                "type": "websocket.close",
                "code": code,
                "reason": reason,
            }
        )
        self._connection_state = WebSocketState.DISCONNECTED

    # =========================================================================
    # Receiving Messages
    # =========================================================================

    async def _receive_message(self) -> Message:
        """
        Internal: receive message with state validation.

        Returns:
            The ASGI message dict.

        Raises:
            RuntimeError: If not connected.
            WebSocketDisconnect: If client disconnected (state becomes
                DISCONNECTED before raising).
        """
        if self._connection_state != WebSocketState.CONNECTED:
            raise RuntimeError(f"Cannot receive: connection in {self._connection_state.name} state")

        message = await self._receive()

        if message["type"] == "websocket.disconnect":
            self._connection_state = WebSocketState.DISCONNECTED
            raise WebSocketDisconnect(
                code=message.get("code", 1000),
                reason=message.get("reason", ""),
            )

        return message

    async def receive_text(self) -> str:
        """
        Receive a text message.

        Returns:
            The text message content ("" if the message carried no text).

        Raises:
            RuntimeError: If not connected.
            TypeError: If received binary data.
            WebSocketDisconnect: If client disconnected.
        """
        message = await self._receive_message()
        if message.get("bytes") is not None:
            raise TypeError("Received binary message. Use receive_bytes() for binary data.")
        # ASGI treats a missing key and an explicit None the same; normalize
        # both to "" so the declared str return type holds.
        text: str | None = message.get("text")
        return text if text is not None else ""

    async def receive_bytes(self) -> bytes:
        """
        Receive a binary message.

        Returns:
            The binary message content (b"" if the message carried no bytes).

        Raises:
            RuntimeError: If not connected.
            TypeError: If received text data.
            WebSocketDisconnect: If client disconnected.
        """
        message = await self._receive_message()
        if message.get("text") is not None:
            raise TypeError("Received text message. Use receive_text() for text data.")
        # Same normalization as receive_text(): explicit None -> b"".
        data: bytes | None = message.get("bytes")
        return data if data is not None else b""

    async def receive_json(self) -> Any:
        """
        Receive and parse a JSON text message.

        Uses orjson when it is installed, otherwise the stdlib json module.

        Returns:
            The parsed JSON value.

        Raises:
            RuntimeError: If not connected.
            TypeError: If received binary data.
            json.JSONDecodeError: If not valid JSON.
            WebSocketDisconnect: If client disconnected.
        """
        text = await self.receive_text()
        if HAS_ORJSON and orjson is not None:
            return orjson.loads(text)
        return json.loads(text)

    # =========================================================================
    # Sending Messages
    # =========================================================================

    async def _send_message(self, message: Message) -> None:
        """
        Internal: send message with state validation.

        Args:
            message: ASGI message dict.

        Raises:
            RuntimeError: If not connected.
        """
        if self._connection_state != WebSocketState.CONNECTED:
            raise RuntimeError(f"Cannot send: connection in {self._connection_state.name} state")
        await self._send(message)

    async def send_text(self, data: str) -> None:
        """
        Send a text message.

        Args:
            data: Text to send.

        Raises:
            RuntimeError: If not connected.
        """
        await self._send_message({"type": "websocket.send", "text": data})

    async def send_bytes(self, data: bytes) -> None:
        """
        Send a binary message.

        Args:
            data: Bytes to send.

        Raises:
            RuntimeError: If not connected.
        """
        await self._send_message({"type": "websocket.send", "bytes": data})

    async def send_json(self, data: Any) -> None:
        """
        Serialize data to JSON and send it as a text message.

        Uses orjson when it is installed, otherwise the stdlib json module.

        Args:
            data: JSON-serializable data.

        Raises:
            RuntimeError: If not connected.
            TypeError: If data not serializable.
        """
        if HAS_ORJSON and orjson is not None:
            text = orjson.dumps(data).decode("utf-8")
        else:
            text = json.dumps(data)
        await self.send_text(text)

    # =========================================================================
    # Iteration
    # =========================================================================

    async def iter_text(self) -> AsyncIterator[str]:
        """
        Async iterator yielding text messages.

        Yields:
            Each text message until disconnect.

        Raises:
            TypeError: If binary message received.
        """
        try:
            while True:
                yield await self.receive_text()
        except WebSocketDisconnect:
            return

    async def iter_bytes(self) -> AsyncIterator[bytes]:
        """
        Async iterator yielding binary messages.

        Yields:
            Each binary message until disconnect.

        Raises:
            TypeError: If text message received.
        """
        try:
            while True:
                yield await self.receive_bytes()
        except WebSocketDisconnect:
            return

    def __aiter__(self) -> AsyncIterator[str]:
        """
        Support async iteration (alias for iter_text).

        Returns:
            Async iterator over text messages.
        """
        return self.iter_text()

    # =========================================================================
    # Typed Messages (requires genro-tytx)
    # =========================================================================

    async def receive_typed(self) -> dict[str, Any]:
        """
        Receive JSON with optional TYTX hydration.

        Returns:
            Parsed dict, hydrated if TYTX marker present.

        Raises:
            ImportError: If genro-tytx not installed.
            RuntimeError: If not connected.
        """
        if not HAS_TYTX:
            raise ImportError(
                "genro-tytx package required for receive_typed(). "
                "Install with: pip install genro-tytx"
            )
        text = await self.receive_text()
        result: dict[str, Any] = from_tytx(text)  # Handles marker detection automatically
        return result

    async def send_typed(self, data: dict[str, Any]) -> None:
        """
        Serialize with TYTX and send with marker.

        Args:
            data: Dict with potentially typed values.

        Raises:
            ImportError: If genro-tytx not installed.
            RuntimeError: If not connected.
        """
        if not HAS_TYTX:
            raise ImportError(
                "genro-tytx package required for send_typed(). "
                "Install with: pip install genro-tytx"
            )
        encoded = to_tytx(data)  # Includes marker automatically
        # to_tytx may return str or bytes; the wire format here is always text.
        text = encoded if isinstance(encoded, str) else encoded.decode("utf-8")
        await self.send_text(text)
if __name__ == "__main__":
    # Self-contained demo: drive WebSocket with an in-memory fake ASGI transport.
    import asyncio

    async def demo() -> None:
        """Demo WebSocket class with mock receive/send."""
        fake_scope: Scope = {
            "type": "websocket",
            "scheme": "wss",
            "path": "/ws/chat",
            "query_string": b"room=general",
            "headers": [
                (b"host", b"example.com"),
                (b"sec-websocket-protocol", b"chat, superchat"),
            ],
            "server": ("example.com", 443),
            "client": ("192.168.1.100", 54321),
            "subprotocols": ["chat", "superchat"],
        }

        # Scripted client -> app messages, handed out one per receive() call.
        inbox: list[Message] = [
            {"type": "websocket.connect"},
            {"type": "websocket.receive", "text": "Hello, server!"},
            {"type": "websocket.receive", "text": '{"action": "ping"}'},
            {"type": "websocket.disconnect", "code": 1000},
        ]
        pending = iter(inbox)
        outbox: list[Message] = []

        async def fake_receive() -> Message:
            queued = next(pending, None)
            if queued is None:
                # Script exhausted: behave like a client that has gone away.
                return {"type": "websocket.disconnect", "code": 1000}
            return queued

        async def fake_send(message: Message) -> None:
            outbox.append(message)

        ws = WebSocket(fake_scope, fake_receive, fake_send)

        connection_info = (
            ("Path", ws.path),
            ("Scheme", ws.scheme),
            ("URL", ws.url),
            ("Client", ws.client),
            ("Subprotocols", ws.subprotocols),
            ("State", ws.connection_state.name),
        )
        for label, value in connection_info:
            print(f"{label}: {value}")
        print()

        await ws.accept(subprotocol="chat")
        print(f"Accepted with subprotocol: {ws.accepted_subprotocol}")
        print(f"State after accept: {ws.connection_state.name}")
        print()

        try:
            greeting = await ws.receive_text()
            print(f"Received text: {greeting}")

            payload = await ws.receive_json()
            print(f"Received JSON: {payload}")

            # The scripted disconnect surfaces here as WebSocketDisconnect.
            await ws.receive_text()
        except WebSocketDisconnect as exc:
            print(f"Disconnected: code={exc.code}")

        print(f"State after disconnect: {ws.connection_state.name}")
        print()

        print(f"Sent messages: {len(outbox)}")
        for sent in outbox:
            print(f"  {sent}")

    asyncio.run(demo())