Source code for genro_asgi.websocket

# Copyright 2025 Softwell S.r.l.
# Licensed under the Apache License, Version 2.0

"""WebSocket connection handling for ASGI applications.

Purpose
=======
This module provides the ``WebSocket`` class for handling WebSocket connections
through the ASGI interface. It wraps the low-level ASGI receive/send callables
in a Pythonic API for managing the connection lifecycle and for sending and
receiving messages.

WebSocket flow in ASGI::

    Client                            WebSocket Class                    ASGI Server
    ──────                            ───────────────                    ───────────
    Connect  ─────────────────────>   WebSocket(scope, receive, send)
                                      state: CONNECTING

                                      await ws.accept()  ───────────>    websocket.accept
                                      (consumes websocket.connect)
                                      state: CONNECTED

                                      await ws.receive_text()  <────────  websocket.receive
                                      await ws.send_text()  ──────────>  websocket.send

                                      await ws.close()  ──────────────>  websocket.close
                                      state: DISCONNECTED

Connection States
=================
WebSocket connections go through three states::

    CONNECTING  ─────>  CONNECTED  ─────>  DISCONNECTED
        │                   │                    │
    Initial state       After accept()      After close() or
    Before accept()     Can send/receive    client disconnect

The ``WebSocketState`` enum represents these states.
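
The transitions above can be sketched as a small lookup table. This is an
illustrative sketch only: the ``ALLOWED`` table and ``can_transition()`` helper
are hypothetical names that do not exist in this module, which enforces the
same progression with explicit checks inside its methods.

```python
from enum import IntEnum


class WebSocketState(IntEnum):
    CONNECTING = 0
    CONNECTED = 1
    DISCONNECTED = 2


# Hypothetical table: each state maps to the only state it may move to.
ALLOWED = {
    WebSocketState.CONNECTING: WebSocketState.CONNECTED,    # accept()
    WebSocketState.CONNECTED: WebSocketState.DISCONNECTED,  # close() or disconnect
}


def can_transition(current: WebSocketState, target: WebSocketState) -> bool:
    # DISCONNECTED has no entry in the table, so it is terminal.
    return ALLOWED.get(current) == target
```

Because the enum is an ``IntEnum``, states also compare as integers, so
``WebSocketState.CONNECTING < WebSocketState.CONNECTED`` holds.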

Imports Required
================
::

    from collections.abc import Mapping
    from enum import IntEnum
    from typing import Any, AsyncIterator
    import json

    from .exceptions import WebSocketDisconnect
    from .types import Scope, Receive, Send, Message
    from .datastructures import (
        Address, Headers, QueryParams, State, URL,
        headers_from_scope, query_params_from_scope
    )

Classes
=======

WebSocketState
--------------
Enum representing WebSocket connection states.

Definition::

    class WebSocketState(IntEnum):
        CONNECTING = 0    # Initial state, before accept()
        CONNECTED = 1     # After accept(), can send/receive messages
        DISCONNECTED = 2  # After close() or client disconnect

``IntEnum`` is used for easy comparisons and potential serialization.

WebSocket
---------
Main class for WebSocket connection handling.

Definition::

    class WebSocket:
        __slots__ = (
            "_scope",           # Scope: ASGI scope dict
            "_receive",         # Receive: ASGI receive callable
            "_send",            # Send: ASGI send callable
            "_connection_state", # WebSocketState: current state
            "_headers",         # Headers | None: lazy headers
            "_query_params",    # QueryParams | None: lazy query params
            "_url",             # URL | None: lazy URL object
            "_state",           # State | None: lazy user state
            "_accepted_subprotocol",  # str | None: negotiated subprotocol
        )

        def __init__(
            self,
            scope: Scope,
            receive: Receive,
            send: Send,
        ) -> None:
            '''
            Initialize WebSocket connection wrapper.

            Args:
                scope: ASGI WebSocket scope dictionary. Must have type="websocket".
                receive: ASGI receive callable for incoming messages.
                send: ASGI send callable for outgoing messages.

            Raises:
                ValueError: If scope type is not "websocket".
            '''
            if scope.get("type") != "websocket":
                raise ValueError(
                    f"Expected scope type 'websocket', got '{scope.get('type')}'"
                )
            self._scope = scope
            self._receive = receive
            self._send = send
            self._connection_state = WebSocketState.CONNECTING
            self._headers: Headers | None = None
            self._query_params: QueryParams | None = None
            self._url: URL | None = None
            self._state: State | None = None
            self._accepted_subprotocol: str | None = None

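The constructor validates that the scope type is "websocket". The ``headers``,
``query_params``, ``url``, and ``state`` properties are initialized lazily on
first access; the remaining properties read directly from the scope.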

Properties
==========

``scope -> Scope``
    The raw ASGI scope dictionary.

``connection_state -> WebSocketState``
    Current connection state (CONNECTING, CONNECTED, or DISCONNECTED).

``path -> str``
    The URL path component from scope. Defaults to "/" if not present.

``scheme -> str``
    URL scheme: "ws" or "wss". Derived from scope, defaults to "ws".

``url -> URL``
    Full WebSocket URL constructed from scope. Built lazily.
    URL construction rules:

    - scheme: "ws" or "wss" from scope
    - host: from the ``server`` tuple if present, otherwise from the Host
      header (falling back to "localhost")
    - port: omitted if default (80 for ws, 443 for wss)
    - path: root_path + path
    - query: query_string if present

``headers -> Headers``
    Connection headers (case-insensitive). Created lazily using
    ``headers_from_scope()``.

``query_params -> QueryParams``
    Query string parameters. Created lazily using ``query_params_from_scope()``.

``state -> State``
    Per-connection state container for middleware/application data.
    Created lazily on first access.

``client -> Address | None``
    Client address (host, port) if available.

``subprotocols -> tuple[str, ...]``
    Subprotocols requested by the client (immutable tuple).

``accepted_subprotocol -> str | None``
    The subprotocol selected in accept(), or None if not set.
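
The URL construction rules listed for ``url`` can be sketched as a standalone
function. This is a minimal sketch, assuming a plain ``dict`` scope;
``build_ws_url()`` is a hypothetical helper name and returns a string rather
than the ``URL`` object the property produces.

```python
def build_ws_url(scope: dict) -> str:
    # Hypothetical helper mirroring the rules listed for the url property.
    scheme = scope.get("scheme", "ws")
    server = scope.get("server")
    path = scope.get("root_path", "") + scope.get("path", "/")
    query_string = scope.get("query_string", b"")

    if server:
        host, port = server
        # Omit the port when it is the default for the scheme.
        default_port = {"ws": 80, "wss": 443}.get(scheme)
        netloc = host if port == default_port else f"{host}:{port}"
    else:
        # Fall back to the Host header (ASGI headers are lowercase byte pairs).
        headers = dict(scope.get("headers", []))
        netloc = headers.get(b"host", b"localhost").decode("latin-1")

    url = f"{scheme}://{netloc}{path}"
    if query_string:
        url += "?" + query_string.decode("latin-1")
    return url
```

For example, a scope with ``scheme="wss"``, ``server=("example.com", 443)``,
``path="/chat"`` and ``query_string=b"room=1"`` yields
``wss://example.com/chat?room=1``.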

Methods
=======

Connection Lifecycle
--------------------

``accept()``
~~~~~~~~~~~~
Accept the WebSocket connection.

Definition::

    async def accept(
        self,
        subprotocol: str | None = None,
        headers: Mapping[str, str] | None = None,
    ) -> None:
        '''
        Accept the WebSocket connection.

        This method MUST be called before sending or receiving messages.
        It consumes the ``websocket.connect`` message and sends
        ``websocket.accept`` to complete the handshake.

        Args:
            subprotocol: Optional subprotocol to use for this connection.
                         Should be one of the client's requested subprotocols
                         (see ``subprotocols``); this is not validated here.
            headers: Optional headers to include in the accept response.
                     Mapping of str to str, converted to ASGI bytes format internally.

        Raises:
            RuntimeError: If connection is not in CONNECTING state, or if the
                          first message received is not ``websocket.connect``.

        Note:
            After calling accept(), connection_state becomes CONNECTED.
            The websocket.connect message is consumed internally.
        '''
        if self._connection_state != WebSocketState.CONNECTING:
            raise RuntimeError(
                f"Cannot accept: connection in {self._connection_state.name} state"
            )

        # Consume the websocket.connect message
        message = await self._receive()
        if message["type"] != "websocket.connect":
            raise RuntimeError(
                f"Expected websocket.connect, got {message['type']}"
            )

        # Build accept message
        accept_message: Message = {"type": "websocket.accept"}
        if subprotocol is not None:
            accept_message["subprotocol"] = subprotocol
            self._accepted_subprotocol = subprotocol
        if headers is not None:
            # Convert str:str to bytes:bytes for ASGI
            accept_message["headers"] = [
                (k.encode("latin-1"), v.encode("latin-1"))
                for k, v in headers.items()
            ]

        await self._send(accept_message)
        self._connection_state = WebSocketState.CONNECTED

``close()``
~~~~~~~~~~~
Close the WebSocket connection.

Definition::

    async def close(
        self,
        code: int = 1000,
        reason: str = "",
    ) -> None:
        '''
        Close the WebSocket connection.

        This method is idempotent: calling it multiple times on an already
        closed connection is safe and does nothing.

        Args:
            code: WebSocket close code. Default 1000 (normal closure).
                  Common codes:
                  - 1000: Normal closure
                  - 1001: Going away
                  - 1002: Protocol error
                  - 1003: Unsupported data type
                  - 1008: Policy violation
                  - 1011: Internal server error
            reason: Optional human-readable close reason (max 123 bytes UTF-8).

        Raises:
            RuntimeError: If the connection has not been accepted yet
                          (still in CONNECTING state).

        Note:
            After calling close(), connection_state becomes DISCONNECTED.
            Calling close() when already DISCONNECTED is a no-op.
        '''
        if self._connection_state == WebSocketState.DISCONNECTED:
            return  # Idempotent: already closed

        if self._connection_state == WebSocketState.CONNECTING:
            raise RuntimeError("Cannot close: connection not accepted yet")

        await self._send({
            "type": "websocket.close",
            "code": code,
            "reason": reason,
        })
        self._connection_state = WebSocketState.DISCONNECTED

Receiving Messages
------------------

``receive_text()``
~~~~~~~~~~~~~~~~~~
Receive a text message.

Definition::

    async def receive_text(self) -> str:
        '''
        Receive a text message from the WebSocket.

        Returns:
            The text message content.

        Raises:
            RuntimeError: If not in CONNECTED state.
            TypeError: If received message is bytes (use receive_bytes()).
            WebSocketDisconnect: If client disconnected.

        Note:
            STRICT mode: This method only accepts text frames. If the client
            sends binary data, TypeError is raised. This prevents silent
            encoding bugs. Use receive_bytes() for binary data.
        '''
        message = await self._receive_message()
        if "bytes" in message and message["bytes"] is not None:
            raise TypeError(
                "Received binary message. Use receive_bytes() for binary data."
            )
        return message.get("text", "")

``receive_bytes()``
~~~~~~~~~~~~~~~~~~~
Receive a binary message.

Definition::

    async def receive_bytes(self) -> bytes:
        '''
        Receive a binary message from the WebSocket.

        Returns:
            The binary message content.

        Raises:
            RuntimeError: If not in CONNECTED state.
            TypeError: If received message is text (use receive_text()).
            WebSocketDisconnect: If client disconnected.

        Note:
            STRICT mode: This method only accepts binary frames. If the client
            sends text data, TypeError is raised. Use receive_text() for text.
        '''
        message = await self._receive_message()
        if "text" in message and message["text"] is not None:
            raise TypeError(
                "Received text message. Use receive_text() for text data."
            )
        return message.get("bytes", b"")
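
The strict text/binary check can be shown on raw ASGI message dicts without a
live connection. This is a minimal sketch; ``extract_text()`` is a hypothetical
helper that mirrors the check in ``receive_text()``, and it treats a ``None``
text payload as an empty string, which is an assumption of this sketch.

```python
def extract_text(message: dict) -> str:
    # Reject binary frames instead of silently decoding them.
    if message.get("bytes") is not None:
        raise TypeError("Received binary message. Use receive_bytes() for binary data.")
    # Assumption of this sketch: a missing or None text payload becomes "".
    return message.get("text") or ""
```

A message such as ``{"type": "websocket.receive", "text": "hi"}`` yields
``"hi"``, while one carrying a ``bytes`` payload raises ``TypeError``.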

``receive_json()``
~~~~~~~~~~~~~~~~~~
Receive and parse a JSON message.

Definition::

    async def receive_json(self) -> Any:
        '''
        Receive a text message and parse it as JSON.

        Returns:
            The parsed JSON value (dict, list, str, int, float, bool, None).

        Raises:
            RuntimeError: If not in CONNECTED state.
            TypeError: If received message is bytes.
            json.JSONDecodeError: If message is not valid JSON.
            WebSocketDisconnect: If client disconnected.

        Note:
            Uses orjson if available for faster parsing, falls back to stdlib.
            JSON decode errors are propagated directly (not wrapped).
            For typed JSON with hydration, see receive_typed() (requires genro-tytx).
        '''
        text = await self.receive_text()
        # Use orjson if available
        if HAS_ORJSON:
            return orjson.loads(text)
        return json.loads(text)

``_receive_message()``
~~~~~~~~~~~~~~~~~~~~~~
Internal method to receive the next message with state validation.

Definition::

    async def _receive_message(self) -> Message:
        '''
        Internal: receive next message with state validation.

        Returns:
            The raw ASGI message dict.

        Raises:
            RuntimeError: If not in CONNECTED state.
            WebSocketDisconnect: If websocket.disconnect received.
        '''
        if self._connection_state != WebSocketState.CONNECTED:
            raise RuntimeError(
                f"Cannot receive: connection in {self._connection_state.name} state"
            )

        message = await self._receive()
        if message["type"] == "websocket.disconnect":
            self._connection_state = WebSocketState.DISCONNECTED
            raise WebSocketDisconnect(
                code=message.get("code", 1000),
                reason=message.get("reason", ""),
            )
        return message

Sending Messages
----------------

``send_text()``
~~~~~~~~~~~~~~~
Send a text message.

Definition::

    async def send_text(self, data: str) -> None:
        '''
        Send a text message to the WebSocket.

        Args:
            data: The text message to send.

        Raises:
            RuntimeError: If not in CONNECTED state.
        '''
        await self._send_message({"type": "websocket.send", "text": data})

``send_bytes()``
~~~~~~~~~~~~~~~~
Send a binary message.

Definition::

    async def send_bytes(self, data: bytes) -> None:
        '''
        Send a binary message to the WebSocket.

        Args:
            data: The binary message to send.

        Raises:
            RuntimeError: If not in CONNECTED state.
        '''
        await self._send_message({"type": "websocket.send", "bytes": data})

``send_json()``
~~~~~~~~~~~~~~~
Send data as JSON.

Definition::

    async def send_json(self, data: Any) -> None:
        '''
        Serialize data to JSON and send as text message.

        Args:
            data: Data to serialize (must be JSON-serializable).

        Raises:
            RuntimeError: If not in CONNECTED state.
            TypeError: If data is not JSON-serializable.

        Note:
            Uses orjson if available for faster serialization, falls back to stdlib.
            For typed JSON with serialization, see send_typed() (requires genro-tytx).
        '''
        if HAS_ORJSON:
            text = orjson.dumps(data).decode("utf-8")
        else:
            text = json.dumps(data)
        await self.send_text(text)

``_send_message()``
~~~~~~~~~~~~~~~~~~~
Internal method to send a message with state validation.

Definition::

    async def _send_message(self, message: Message) -> None:
        '''
        Internal: send message with state validation.

        Args:
            message: ASGI message dict to send.

        Raises:
            RuntimeError: If not in CONNECTED state.
        '''
        if self._connection_state != WebSocketState.CONNECTED:
            raise RuntimeError(
                f"Cannot send: connection in {self._connection_state.name} state"
            )
        await self._send(message)

Iteration
---------

``iter_text()``
~~~~~~~~~~~~~~~
Async iterator for text messages.

Definition::

    async def iter_text(self) -> AsyncIterator[str]:
        '''
        Async iterator yielding text messages.

        Yields text messages until the connection closes or client disconnects.

        Yields:
            str: Each text message received.

        Raises:
            TypeError: If a binary message is received.

        Example:
            async for message in ws.iter_text():
                print(f"Received: {message}")
        '''
        try:
            while True:
                yield await self.receive_text()
        except WebSocketDisconnect:
            return

``iter_bytes()``
~~~~~~~~~~~~~~~~
Async iterator for binary messages.

Definition::

    async def iter_bytes(self) -> AsyncIterator[bytes]:
        '''
        Async iterator yielding binary messages.

        Yields binary messages until the connection closes or client disconnects.

        Yields:
            bytes: Each binary message received.

        Raises:
            TypeError: If a text message is received.

        Example:
            async for data in ws.iter_bytes():
                process_binary(data)
        '''
        try:
            while True:
                yield await self.receive_bytes()
        except WebSocketDisconnect:
            return

``__aiter__()``
~~~~~~~~~~~~~~~
Async iteration support (alias for iter_text).

Definition::

    def __aiter__(self) -> AsyncIterator[str]:
        '''
        Support async iteration over text messages.

        Equivalent to iter_text(). Use iter_bytes() for binary messages.

        Example:
            async for message in ws:  # Same as: async for message in ws.iter_text()
                handle_message(message)
        '''
        return self.iter_text()
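
The iterator pattern above ends the loop silently on disconnect, which a
standalone sketch makes visible. The names ``Disconnect``, ``iter_text()`` and
``collect()`` below are stand-ins written for this illustration; they imitate
the module's behavior with a scripted ``receive`` callable.

```python
import asyncio
from typing import AsyncIterator


class Disconnect(Exception):
    # Stand-in for WebSocketDisconnect in this sketch.
    pass


async def iter_text(receive) -> AsyncIterator[str]:
    # Yield text payloads until a websocket.disconnect event arrives.
    try:
        while True:
            message = await receive()
            if message["type"] == "websocket.disconnect":
                raise Disconnect
            yield message.get("text", "")
    except Disconnect:
        return


async def collect() -> list[str]:
    events = iter([
        {"type": "websocket.receive", "text": "a"},
        {"type": "websocket.receive", "text": "b"},
        {"type": "websocket.disconnect", "code": 1000},
    ])

    async def receive() -> dict:
        return next(events)

    return [text async for text in iter_text(receive)]
```

Since the disconnect is caught inside the generator, code placed after an
``async for`` loop runs on disconnect, and an ``except`` clause wrapped around
the loop does not see the exception.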

Typed Messages (requires genro-tytx)
------------------------------------

``receive_typed()``
~~~~~~~~~~~~~~~~~~~
Receive JSON with TYTX hydration.

Definition::

    async def receive_typed(self) -> dict[str, Any]:
        '''
        Receive a text message with optional TYTX hydration.

        If the message ends with "::TYTX" marker, the content is parsed as
        JSON and hydrated using genro-tytx to restore Python types
        (Decimal, datetime, etc.).

        Returns:
            The parsed and optionally hydrated dict.

        Raises:
            RuntimeError: If not in CONNECTED state.
            ImportError: If the message carries the ::TYTX marker and
                         genro-tytx is not installed.
            json.JSONDecodeError: If message is not valid JSON.
            WebSocketDisconnect: If client disconnected.

        Note:
            genro-tytx is only needed for messages carrying the ::TYTX marker.
            Plain JSON (without the marker) is parsed and returned as-is.

        Example:
            # Client sends: {"price": "100.50::D", "date": "2025-01-15::d"}::TYTX
            data = await ws.receive_typed()
            # data["price"] is Decimal("100.50")
            # data["date"] is date(2025, 1, 15)
        '''
        text = await self.receive_text()

        if text.endswith("::TYTX"):
            if not HAS_TYTX:
                raise ImportError(
                    "genro-tytx package required for receive_typed(). "
                    "Install with: pip install genro-tytx"
                )
            json_str = text[:-6]  # Remove "::TYTX" marker
            data = json.loads(json_str)
            return hydrate(data)  # From genro-tytx
        else:
            return json.loads(text)
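
The marker handling can be isolated from hydration in a small helper. This is a
sketch only: ``TYTX_MARKER`` and ``split_tytx()`` are hypothetical names, and
the hydration step itself is left to genro-tytx and not reproduced here.

```python
import json

TYTX_MARKER = "::TYTX"  # Marker described above; the constant name is hypothetical.


def split_tytx(text: str) -> tuple[object, bool]:
    # Return (parsed JSON, is_typed). Typed payloads still need hydration.
    if text.endswith(TYTX_MARKER):
        return json.loads(text[: -len(TYTX_MARKER)]), True
    return json.loads(text), False
```

A caller would pass the first element through the TYTX hydration step only when
the second element is ``True``.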

``send_typed()``
~~~~~~~~~~~~~~~~
Send data with TYTX serialization.

Definition::

    async def send_typed(self, data: dict[str, Any]) -> None:
        '''
        Serialize data with TYTX and send with marker.

        Serializes Python types (Decimal, datetime, etc.) to TYTX format
        and appends "::TYTX" marker for the receiver to identify typed data.

        Args:
            data: Dict containing potentially typed values (Decimal, date, etc.).

        Raises:
            RuntimeError: If not in CONNECTED state.
            ImportError: If genro-tytx is not installed.

        Note:
            Requires genro-tytx package to be installed.

        Example:
            await ws.send_typed({
                "price": Decimal("100.50"),
                "created": datetime.now(),
            })
            # Sends: {"price": "100.50::D", "created": "2025-01-15T12:30:00::dt"}::TYTX
        '''
        if not HAS_TYTX:
            raise ImportError(
                "genro-tytx package required for send_typed(). "
                "Install with: pip install genro-tytx"
            )
        serialized = serialize(data)  # From genro-tytx
        text = json.dumps(serialized) + "::TYTX"
        await self.send_text(text)

Exception Classes
=================

WebSocketDisconnect
-------------------
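Exception raised when the client disconnects. The implementation imports this
class from ``.exceptions`` rather than defining it in this module.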

Definition::

    class WebSocketDisconnect(Exception):
        '''
        Raised when the WebSocket client disconnects.

        Attributes:
            code: WebSocket close code (default 1000).
            reason: Optional close reason string.

        Example:
            try:
                data = await ws.receive_text()
            except WebSocketDisconnect as e:
                print(f"Client disconnected: code={e.code}, reason={e.reason}")
        '''
        def __init__(self, code: int = 1000, reason: str = "") -> None:
            self.code = code
            self.reason = reason
            super().__init__(f"WebSocket disconnected: code={code}, reason={reason}")

Module Constants
================
::

    HAS_ORJSON: bool  # True if orjson is available
    HAS_TYTX: bool    # True if genro-tytx is available

    # Conditional imports
    try:
        import orjson
        HAS_ORJSON = True
    except ImportError:
        HAS_ORJSON = False

    try:
        from genro_tytx import hydrate, serialize
        HAS_TYTX = True
    except ImportError:
        HAS_TYTX = False
        hydrate = None  # type: ignore
        serialize = None  # type: ignore

Public Exports
==============
::

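    __all__ = [
        "WebSocket",
        "WebSocketState",
    ]

``WebSocketDisconnect`` is imported from ``.exceptions`` and is therefore
available as a module attribute, but it is not listed in ``__all__``.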

Examples
========
Basic WebSocket handler::

    from genro_asgi.websocket import WebSocket

    async def websocket_handler(scope, receive, send):
        ws = WebSocket(scope, receive, send)

        await ws.accept()
        print(f"Client connected from {ws.client}")

        # The iterator catches WebSocketDisconnect internally, so the loop
        # simply ends when the client disconnects.
        async for message in ws:
            # Echo back
            await ws.send_text(f"You said: {message}")
        print("Client disconnected")

With subprotocol negotiation::

    async def handler(scope, receive, send):
        ws = WebSocket(scope, receive, send)

        # Check client's requested subprotocols
        if "graphql-ws" in ws.subprotocols:
            await ws.accept(subprotocol="graphql-ws")
        else:
            # close() raises RuntimeError while CONNECTING, so accept first
            await ws.accept()
            await ws.close(code=1002, reason="Unsupported protocol")
            return

        # Handle GraphQL WebSocket protocol
        ...

JSON messaging::

    async def handler(scope, receive, send):
        ws = WebSocket(scope, receive, send)
        await ws.accept()

        try:
            while True:
                data = await ws.receive_json()
                response = {"received": data, "status": "ok"}
                await ws.send_json(response)
        except WebSocketDisconnect:
            pass

With TYTX typed data::

    from decimal import Decimal
    from datetime import date

    async def handler(scope, receive, send):
        ws = WebSocket(scope, receive, send)
        await ws.accept()

        # Receive typed data
        data = await ws.receive_typed()
        # data["amount"] might be Decimal, data["date"] might be date

        # Send typed data
        await ws.send_typed({
            "total": Decimal("123.45"),
            "processed_at": date.today(),
        })

Design Decisions
================
1. **WebSocketState as IntEnum**:
   IntEnum provides both type safety and easy comparisons. Integer values
   match the conceptual progression: 0 (connecting), 1 (connected),
   2 (disconnected).

2. **STRICT receive_text()/receive_bytes()**:
   Unlike some frameworks that auto-convert between text and bytes, we
   raise TypeError if the wrong type is received. This prevents silent
   encoding bugs and makes type mismatches explicit.

3. **accept() consumes websocket.connect**:
   Per the ASGI spec, the server sends a websocket.connect event when the
   client opens the connection, before the application replies with
   websocket.accept. Our accept() method handles this transparently,
   consuming the connect message before sending accept.

4. **close() is idempotent**:
   Calling close() multiple times is safe (no-op if already disconnected).
   This simplifies cleanup code and context managers.

5. **Separate iter_text()/iter_bytes()**:
   Instead of a generic iterator that returns Union[str, bytes], we provide
   separate methods for type safety. __aiter__ aliases iter_text() as the
   common case.

6. **User-friendly accept() headers**:
   The accept() method takes Mapping[str, str] for headers instead of
   ASGI's list[tuple[bytes, bytes]]. Conversion is handled internally.

7. **Lazy property initialization**:
   Headers, query_params, URL, and state are created only when accessed.
   This avoids unnecessary work for simple echo handlers.

8. **URL construction duplicated from Request**:
   Rather than creating a shared utility prematurely, URL construction logic
   is duplicated. A shared url_from_scope() utility will be added in Block 13
   during final integration if the pattern proves valuable.

9. **Typed methods require explicit import**:
   receive_typed()/send_typed() raise ImportError if genro-tytx is not
   installed, rather than silently degrading. This makes dependencies explicit.

10. **subprotocols returns tuple**:
    Returns immutable tuple[str, ...] instead of mutable list to prevent
    accidental modification of scope data.

What This Module Does NOT Include
=================================
- **Message buffering**: No internal message queue. Each ``receive_*()`` call
  goes directly to the ASGI receive callable. Buffering can be added at
  application level.

- **Ping/pong handling**: ASGI servers typically handle ping/pong frames
  automatically. We don't expose these.

- **Connection timeout**: No built-in timeout for receive operations.
  Use asyncio.timeout() or similar at application level.

- **Automatic reconnection**: This is a single-connection wrapper.
  Reconnection logic belongs at application level.

- **Message validation**: No schema validation. Use pydantic or similar
  after receive_json() if needed.
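
An application-level receive timeout, as suggested in the list above, might
look like the following. This is a minimal sketch using ``asyncio.wait_for``
from the standard library; ``receive_with_timeout()`` and ``demo()`` are
hypothetical names, and the silent client is simulated with a long sleep.

```python
import asyncio


async def receive_with_timeout(receive, seconds: float):
    # Wrap any receive coroutine function; raises asyncio.TimeoutError on expiry.
    return await asyncio.wait_for(receive(), timeout=seconds)


async def demo() -> str:
    async def never_arrives() -> str:
        await asyncio.sleep(3600)  # Simulates a client that stays silent.
        return "unreachable"

    try:
        return await receive_with_timeout(never_arrives, 0.01)
    except asyncio.TimeoutError:
        return "timed out"
```

In a handler, the same wrapper could be applied to a bound method, for example
``await receive_with_timeout(ws.receive_text, 30)``.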

References
==========
- ASGI WebSocket Spec: https://asgi.readthedocs.io/en/latest/specs/www.html#websocket
- WebSocket Close Codes: https://developer.mozilla.org/en-US/docs/Web/API/CloseEvent/code
- RFC 6455 (WebSocket Protocol): https://tools.ietf.org/html/rfc6455
"""

from __future__ import annotations

import json
from collections.abc import Mapping
from enum import IntEnum
from typing import TYPE_CHECKING, Any, AsyncIterator

from .datastructures import (
    Address,
    Headers,
    QueryParams,
    State,
    URL,
    headers_from_scope,
    query_params_from_scope,
)
from .exceptions import WebSocketDisconnect
from .types import Message, Receive, Scope, Send

__all__ = [
    "WebSocket",
    "WebSocketState",
]

# Optional dependency: orjson for faster JSON
try:
    import orjson

    HAS_ORJSON = True
except ImportError:
    orjson = None  # type: ignore[assignment]
    HAS_ORJSON = False

# Optional dependency: genro-tytx for typed serialization
try:
    from genro_tytx import from_tytx, to_tytx

    HAS_TYTX = True
except ImportError:
    HAS_TYTX = False
    from_tytx = None  # type: ignore[assignment]
    to_tytx = None  # type: ignore[assignment]


if TYPE_CHECKING:
    pass


class WebSocketState(IntEnum):
    """
    WebSocket connection state.

    Represents the three possible states of a WebSocket connection:
    - CONNECTING: Initial state, before accept() is called
    - CONNECTED: After accept(), can send and receive messages
    - DISCONNECTED: After close() or client disconnect

    Example:
        >>> ws = WebSocket(scope, receive, send)
        >>> ws.connection_state == WebSocketState.CONNECTING
        True
        >>> await ws.accept()
        >>> ws.connection_state == WebSocketState.CONNECTED
        True
    """

    CONNECTING = 0
    CONNECTED = 1
    DISCONNECTED = 2


class WebSocket:
    """
    WebSocket connection wrapper for ASGI.

    Provides a Pythonic interface for handling WebSocket connections through
    the ASGI protocol. Manages connection lifecycle (accept, close) and
    message sending/receiving.

    Attributes:
        scope: The ASGI scope dictionary.
        connection_state: Current WebSocketState.

    Example:
        >>> async def handler(scope, receive, send):
        ...     ws = WebSocket(scope, receive, send)
        ...     await ws.accept()
        ...     message = await ws.receive_text()
        ...     await ws.send_text(f"Echo: {message}")
        ...     await ws.close()
    """

    __slots__ = (
        "_scope",
        "_receive",
        "_send",
        "_connection_state",
        "_headers",
        "_query_params",
        "_url",
        "_state",
        "_accepted_subprotocol",
    )

    def __init__(
        self,
        scope: Scope,
        receive: Receive,
        send: Send,
    ) -> None:
        """
        Initialize WebSocket connection wrapper.

        Args:
            scope: ASGI WebSocket scope dictionary.
            receive: ASGI receive callable.
            send: ASGI send callable.

        Raises:
            ValueError: If scope type is not "websocket".
        """
        if scope.get("type") != "websocket":
            raise ValueError(f"Expected scope type 'websocket', got '{scope.get('type')}'")
        self._scope = scope
        self._receive = receive
        self._send = send
        self._connection_state = WebSocketState.CONNECTING
        self._headers: Headers | None = None
        self._query_params: QueryParams | None = None
        self._url: URL | None = None
        self._state: State | None = None
        self._accepted_subprotocol: str | None = None

    # =========================================================================
    # Properties
    # =========================================================================

    @property
    def scope(self) -> Scope:
        """The raw ASGI scope dictionary."""
        return self._scope

    @property
    def connection_state(self) -> WebSocketState:
        """Current connection state."""
        return self._connection_state

    @property
    def path(self) -> str:
        """URL path component."""
        return str(self._scope.get("path", "/"))

    @property
    def scheme(self) -> str:
        """URL scheme ('ws' or 'wss')."""
        return str(self._scope.get("scheme", "ws"))

    @property
    def url(self) -> URL:
        """
        Full WebSocket URL.

        Constructed lazily from scope fields: scheme, server, root_path,
        path, and query_string.

        Returns:
            URL object representing the complete WebSocket URL.
        """
        if self._url is None:
            scheme = self.scheme
            server = self._scope.get("server")
            path = self._scope.get("root_path", "") + self.path
            query_string = self._scope.get("query_string", b"")

            if server:
                host, port = server
                # Omit default ports
                if (scheme == "ws" and port == 80) or (scheme == "wss" and port == 443):
                    netloc = host
                else:
                    netloc = f"{host}:{port}"
            else:
                netloc = self.headers.get("host", "localhost")

            url_str = f"{scheme}://{netloc}{path}"
            if query_string:
                url_str += f"?{query_string.decode('latin-1')}"

            self._url = URL(url_str)
        return self._url

    @property
    def headers(self) -> Headers:
        """Connection headers (case-insensitive)."""
        if self._headers is None:
            self._headers = headers_from_scope(self._scope)
        return self._headers

    @property
    def query_params(self) -> QueryParams:
        """Query string parameters."""
        if self._query_params is None:
            self._query_params = query_params_from_scope(self._scope)
        return self._query_params

    @property
    def state(self) -> State:
        """Per-connection state container."""
        if self._state is None:
            self._state = State()
        return self._state

    @property
    def client(self) -> Address | None:
        """Client address (host, port) if available."""
        client = self._scope.get("client")
        if client:
            return Address(client[0], client[1])
        return None

    @property
    def subprotocols(self) -> tuple[str, ...]:
        """Subprotocols requested by the client (immutable)."""
        return tuple(self._scope.get("subprotocols", []))

    @property
    def accepted_subprotocol(self) -> str | None:
        """The subprotocol selected in accept(), or None."""
        return self._accepted_subprotocol

    # =========================================================================
    # Connection Lifecycle
    # =========================================================================

    async def accept(
        self,
        subprotocol: str | None = None,
        headers: Mapping[str, str] | None = None,
    ) -> None:
        """
        Accept the WebSocket connection.

        Must be called before sending or receiving messages.
        Consumes the ``websocket.connect`` message and sends ``websocket.accept``.

        Args:
            subprotocol: Optional subprotocol to negotiate.
            headers: Optional headers to include in accept response.

        Raises:
            RuntimeError: If not in CONNECTING state or unexpected message.
        """
        if self._connection_state != WebSocketState.CONNECTING:
            raise RuntimeError(f"Cannot accept: connection in {self._connection_state.name} state")

        # Consume the websocket.connect message
        message = await self._receive()
        if message["type"] != "websocket.connect":
            raise RuntimeError(f"Expected websocket.connect, got {message['type']}")

        # Build accept message
        accept_message: Message = {"type": "websocket.accept"}
        if subprotocol is not None:
            accept_message["subprotocol"] = subprotocol
            self._accepted_subprotocol = subprotocol
        if headers is not None:
            accept_message["headers"] = [
                (k.encode("latin-1"), v.encode("latin-1")) for k, v in headers.items()
            ]

        await self._send(accept_message)
        self._connection_state = WebSocketState.CONNECTED

    async def close(
        self,
        code: int = 1000,
        reason: str = "",
    ) -> None:
        """
        Close the WebSocket connection.

        Idempotent: safe to call multiple times.

        Args:
            code: WebSocket close code (default 1000).
            reason: Optional close reason.

        Raises:
            RuntimeError: If connection not accepted yet.
        """
        if self._connection_state == WebSocketState.DISCONNECTED:
            return  # Idempotent

        if self._connection_state == WebSocketState.CONNECTING:
            raise RuntimeError("Cannot close: connection not accepted yet")

        await self._send(
            {
                "type": "websocket.close",
                "code": code,
                "reason": reason,
            }
        )
        self._connection_state = WebSocketState.DISCONNECTED

    # =========================================================================
    # Receiving Messages
    # =========================================================================

    async def _receive_message(self) -> Message:
        """
        Internal: receive message with state validation.

        Returns:
            The ASGI message dict.

        Raises:
            RuntimeError: If not connected.
            WebSocketDisconnect: If client disconnected.
        """
        if self._connection_state != WebSocketState.CONNECTED:
            raise RuntimeError(f"Cannot receive: connection in {self._connection_state.name} state")

        message = await self._receive()

        if message["type"] == "websocket.disconnect":
            self._connection_state = WebSocketState.DISCONNECTED
            raise WebSocketDisconnect(
                code=message.get("code", 1000),
                reason=message.get("reason", ""),
            )

        return message

    async def receive_text(self) -> str:
        """
        Receive a text message.

        Returns:
            The text message content.

        Raises:
            RuntimeError: If not connected.
            TypeError: If received binary data.
            WebSocketDisconnect: If client disconnected.
        """
        message = await self._receive_message()
        if "bytes" in message and message["bytes"] is not None:
            raise TypeError("Received binary message. Use receive_bytes() for binary data.")
        text: str = message.get("text", "")
        return text

    async def receive_bytes(self) -> bytes:
        """
        Receive a binary message.

        Returns:
            The binary message content.

        Raises:
            RuntimeError: If not connected.
            TypeError: If received text data.
            WebSocketDisconnect: If client disconnected.
        """
        message = await self._receive_message()
        if "text" in message and message["text"] is not None:
            raise TypeError("Received text message. Use receive_text() for text data.")
        data: bytes = message.get("bytes", b"")
        return data

    async def receive_json(self) -> Any:
        """
        Receive and parse a JSON text message.

        Returns:
            The parsed JSON value.

        Raises:
            RuntimeError: If not connected.
            TypeError: If received binary data.
            json.JSONDecodeError: If not valid JSON.
            WebSocketDisconnect: If client disconnected.
        """
        text = await self.receive_text()
        if HAS_ORJSON and orjson is not None:
            return orjson.loads(text)
        return json.loads(text)

    # =========================================================================
    # Sending Messages
    # =========================================================================

    async def _send_message(self, message: Message) -> None:
        """
        Internal: send message with state validation.

        Args:
            message: ASGI message dict.

        Raises:
            RuntimeError: If not connected.
        """
        if self._connection_state != WebSocketState.CONNECTED:
            raise RuntimeError(f"Cannot send: connection in {self._connection_state.name} state")
        await self._send(message)

    async def send_text(self, data: str) -> None:
        """
        Send a text message.

        Args:
            data: Text to send.

        Raises:
            RuntimeError: If not connected.
        """
        await self._send_message({"type": "websocket.send", "text": data})

    async def send_bytes(self, data: bytes) -> None:
        """
        Send a binary message.

        Args:
            data: Bytes to send.

        Raises:
            RuntimeError: If not connected.
        """
        await self._send_message({"type": "websocket.send", "bytes": data})

    async def send_json(self, data: Any) -> None:
        """
        Serialize data to JSON and send.

        Args:
            data: JSON-serializable data.

        Raises:
            RuntimeError: If not connected.
            TypeError: If data not serializable.
        """
        if HAS_ORJSON and orjson is not None:
            text = orjson.dumps(data).decode("utf-8")
        else:
            text = json.dumps(data)
        await self.send_text(text)

    # =========================================================================
    # Iteration
    # =========================================================================

    async def iter_text(self) -> AsyncIterator[str]:
        """
        Async iterator yielding text messages.

        Yields:
            Each text message until disconnect.

        Raises:
            TypeError: If binary message received.
        """
        try:
            while True:
                yield await self.receive_text()
        except WebSocketDisconnect:
            return

    async def iter_bytes(self) -> AsyncIterator[bytes]:
        """
        Async iterator yielding binary messages.

        Yields:
            Each binary message until disconnect.

        Raises:
            TypeError: If text message received.
        """
        try:
            while True:
                yield await self.receive_bytes()
        except WebSocketDisconnect:
            return

    def __aiter__(self) -> AsyncIterator[str]:
        """
        Support async iteration (alias for iter_text).

        Returns:
            Async iterator over text messages.
        """
        return self.iter_text()

    # =========================================================================
    # Typed Messages (requires genro-tytx)
    # =========================================================================

    async def receive_typed(self) -> dict[str, Any]:
        """
        Receive JSON with optional TYTX hydration.

        Returns:
            Parsed dict, hydrated if TYTX marker present.

        Raises:
            ImportError: If genro-tytx not installed.
            RuntimeError: If not connected.
        """
        if not HAS_TYTX:
            raise ImportError(
                "genro-tytx package required for receive_typed(). "
                "Install with: pip install genro-tytx"
            )
        text = await self.receive_text()
        result: dict[str, Any] = from_tytx(text)  # Handles marker detection automatically
        return result

    async def send_typed(self, data: dict[str, Any]) -> None:
        """
        Serialize with TYTX and send with marker.

        Args:
            data: Dict with potentially typed values.

        Raises:
            ImportError: If genro-tytx not installed.
            RuntimeError: If not connected.
        """
        if not HAS_TYTX:
            raise ImportError(
                "genro-tytx package required for send_typed(). "
                "Install with: pip install genro-tytx"
            )
        encoded = to_tytx(data)  # Includes marker automatically
        text = encoded if isinstance(encoded, str) else encoded.decode("utf-8")
        await self.send_text(text)


if __name__ == "__main__":
    # Minimal demo with mock ASGI interface
    import asyncio

    async def demo() -> None:
        """Demo WebSocket class with mock receive/send."""
        # Mock scope
        scope: Scope = {
            "type": "websocket",
            "scheme": "wss",
            "path": "/ws/chat",
            "query_string": b"room=general",
            "headers": [
                (b"host", b"example.com"),
                (b"sec-websocket-protocol", b"chat, superchat"),
            ],
            "server": ("example.com", 443),
            "client": ("192.168.1.100", 54321),
            "subprotocols": ["chat", "superchat"],
        }

        # Mock message queue
        messages: list[Message] = [
            {"type": "websocket.connect"},
            {"type": "websocket.receive", "text": "Hello, server!"},
            {"type": "websocket.receive", "text": '{"action": "ping"}'},
            {"type": "websocket.disconnect", "code": 1000},
        ]
        msg_index = 0
        sent_messages: list[Message] = []

        async def mock_receive() -> Message:
            nonlocal msg_index
            if msg_index < len(messages):
                msg = messages[msg_index]
                msg_index += 1
                return msg
            return {"type": "websocket.disconnect", "code": 1000}

        async def mock_send(message: Message) -> None:
            sent_messages.append(message)

        # Create WebSocket instance
        ws = WebSocket(scope, mock_receive, mock_send)

        print(f"Path: {ws.path}")
        print(f"Scheme: {ws.scheme}")
        print(f"URL: {ws.url}")
        print(f"Client: {ws.client}")
        print(f"Subprotocols: {ws.subprotocols}")
        print(f"State: {ws.connection_state.name}")
        print()

        # Accept connection
        await ws.accept(subprotocol="chat")
        print(f"Accepted with subprotocol: {ws.accepted_subprotocol}")
        print(f"State after accept: {ws.connection_state.name}")
        print()

        # Receive messages
        try:
            message = await ws.receive_text()
            print(f"Received text: {message}")

            data = await ws.receive_json()
            print(f"Received JSON: {data}")

            # This will trigger disconnect
            await ws.receive_text()
        except WebSocketDisconnect as e:
            print(f"Disconnected: code={e.code}")
            print(f"State after disconnect: {ws.connection_state.name}")

        print()
        print(f"Sent messages: {len(sent_messages)}")
        for msg in sent_messages:
            print(f"  {msg}")

    asyncio.run(demo())
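
The demo above drives the class through mock `receive`/`send` callables. As a companion, the following is a minimal sketch of the same message exchange written against raw ASGI callables only, so it runs without `genro_asgi` installed. The names `echo_app` and `run_with_mocks` are hypothetical and not part of this module, and `Message` here is a local stand-in alias rather than the type imported from `.types`.

```python
import asyncio
from typing import Any

Message = dict[str, Any]  # local stand-in alias (assumption), not genro_asgi's type


async def echo_app(scope: Message, receive, send) -> None:
    """Hypothetical raw-ASGI echo endpoint; no WebSocket wrapper involved."""
    # Same handshake accept() performs: consume websocket.connect, then accept.
    message = await receive()
    if message["type"] != "websocket.connect":
        raise RuntimeError(f"Expected websocket.connect, got {message['type']}")
    await send({"type": "websocket.accept"})

    # Echo text frames back until the client disconnects.
    while True:
        message = await receive()
        if message["type"] == "websocket.disconnect":
            return
        await send({"type": "websocket.send", "text": message.get("text") or ""})


async def run_with_mocks(incoming: list[Message]) -> list[Message]:
    """Feed a scripted message list to echo_app and collect what it sends."""
    queue = list(incoming)
    sent: list[Message] = []

    async def receive() -> Message:
        if queue:
            return queue.pop(0)
        # Once the script is exhausted, report a normal closure.
        return {"type": "websocket.disconnect", "code": 1000}

    async def send(message: Message) -> None:
        sent.append(message)

    await echo_app({"type": "websocket"}, receive, send)
    return sent


sent = asyncio.run(
    run_with_mocks(
        [
            {"type": "websocket.connect"},
            {"type": "websocket.receive", "text": "Hello, server!"},
            {"type": "websocket.disconnect", "code": 1000},
        ]
    )
)
```

Comparing `sent` with the `sent_messages` list printed by the demo shows which ASGI messages the wrapper methods emit on the caller's behalf.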