Source code for genro_asgi.request

# Copyright 2025 Softwell S.r.l.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Transport-agnostic request system.

This module provides the complete request handling infrastructure:
- BaseRequest: Abstract interface for all request types
- HttpRequest: ASGI HTTP request adapter
- MsgRequest: Message-based request adapter (WSX over WebSocket, NATS)
- RequestRegistry: Factory and tracking for active requests

Architecture:
    BaseRequest (ABC)
        ├── HttpRequest     # ASGI HTTP scope
        └── MsgRequest      # WSX message (WebSocket, NATS)

Every request:
1. Gets a unique `id` (correlation ID)
2. Is registered in RequestRegistry
3. Has `app_name` for per-app metrics
4. Has `created_at` for age tracking
5. Is unregistered on completion

Example:
    registry = RequestRegistry()
    request = await registry.create(scope, receive, send)
    try:
        result = await handler(request)
    finally:
        registry.unregister()
"""

from __future__ import annotations

import json as stdlib_json
import time
import uuid
from abc import ABC, abstractmethod
from collections.abc import Iterator
from contextvars import ContextVar
from typing import TYPE_CHECKING, Any

from .datastructures import (
    Address,
    Headers,
    QueryParams,
    State,
    URL,
    headers_from_scope,
    query_params_from_scope,
)
from .types import Receive, Scope, Send

if TYPE_CHECKING:
    from .websocket import WebSocket

# Names exported by ``from genro_asgi.request import *``; every entry is
# defined in this module.
__all__ = [
    "BaseRequest",
    "HttpRequest",
    "MsgRequest",
    "RequestRegistry",
    "REQUEST_FACTORIES",
    "get_current_request",
    "set_current_request",
]

# Module-level ContextVar holding the request bound to the current context,
# so any code can reach it via get_current_request() without having it passed
# explicitly. Its default is None, i.e. "not inside a request context".
_current_request: ContextVar["BaseRequest | None"] = ContextVar("current_request", default=None)


def get_current_request() -> "BaseRequest | None":
    """Return the request bound to the current context.

    Returns:
        The active request, or ``None`` when no request has been set in
        this context.
    """
    active = _current_request.get()
    return active
def set_current_request(request: "BaseRequest | None") -> Any:
    """Bind ``request`` as the current request of this context.

    Args:
        request: The request to expose, or ``None`` to clear it.

    Returns:
        The token produced by ``ContextVar.set``, which can be used to
        restore the previous value.
    """
    token = _current_request.set(request)
    return token
class BaseRequest(ABC):
    """
    Common contract for every request, independent of the transport.

    Concrete adapters (HTTP, message-based) implement the abstract members
    below so that handlers can treat all requests the same way.

    Properties:
        id: Server-generated correlation ID (internal)
        external_id: Client-provided ID for correlation (optional)
        method: HTTP method (GET, POST, PUT, DELETE, PATCH)
        path: Request path (e.g., '/users/42')
        headers: Request headers as dict
        cookies: Request cookies as dict
        query: Query parameters
        data: Request body/payload
        transport: Transport type ('http', 'websocket', 'nats')
        app_name: Name of the app handling this request (for metrics)
        created_at: Timestamp when request was created (for age tracking)
        age: Seconds elapsed since ``created_at``
        auth_tags: Auth tags attached to the request
        env_capabilities: Environment capabilities attached to the request
        tytx_mode: True if request uses TYTX serialization
        tytx_transport: TYTX transport type ('json', 'msgpack') or None

    Attributes:
        response: ``Response`` object created together with the request and
            bound to it.
    """

    __slots__ = (
        "_app_name",
        "_auth_tags",
        "_created_at",
        "_env_capabilities",
        "_external_id",
        "_tytx_mode",
        "_tytx_transport",
        "response",
    )

    def __init__(self) -> None:
        # Imported at call time rather than at module import time
        # (presumably to avoid an import cycle with .response - TODO confirm).
        from .response import Response

        self._created_at: float = time.time()
        self._app_name: str | None = None
        self._external_id: str | None = None
        self._auth_tags: list[str] = []
        self._env_capabilities: list[str] = []
        self._tytx_mode: bool = False
        self._tytx_transport: str | None = None
        # Built last so the response is handed a fully initialised request.
        self.response: Response = Response(request=self)

    # ------------------------------------------------------------------
    # Abstract interface - supplied by each transport adapter
    # ------------------------------------------------------------------

    @property
    @abstractmethod
    def id(self) -> str:
        """Correlation ID used to match a response to its request."""

    @property
    @abstractmethod
    def method(self) -> str:
        """HTTP-style verb: GET, POST, PUT, DELETE, PATCH."""

    @property
    @abstractmethod
    def path(self) -> str:
        """Path being requested (e.g., '/users/42')."""

    @property
    @abstractmethod
    def headers(self) -> dict[str, str]:
        """Headers of the request (lowercase keys)."""

    @property
    @abstractmethod
    def cookies(self) -> dict[str, str]:
        """Cookies sent with the request."""

    @property
    @abstractmethod
    def query(self) -> dict[str, Any]:
        """Query parameters of the request."""

    @property
    @abstractmethod
    def data(self) -> Any:
        """Body/payload of the request."""

    @property
    @abstractmethod
    def transport(self) -> str:
        """Kind of transport: 'http', 'websocket', 'nats'."""

    @abstractmethod
    async def init(
        self,
        scope: Scope,
        receive: Receive,
        send: Send | None = None,
        **kwargs: Any,
    ) -> None:
        """Asynchronous second-phase initialisation; subclasses must override."""

    # ------------------------------------------------------------------
    # Concrete, shared state
    # ------------------------------------------------------------------

    @property
    def auth_tags(self) -> list[str]:
        """Auth tags (set from scope during init by AuthMiddleware)."""
        return self._auth_tags

    @property
    def env_capabilities(self) -> list[str]:
        """Environment capabilities (set from scope during init)."""
        return self._env_capabilities

    @property
    def external_id(self) -> str | None:
        """ID supplied by the client for correlation (e.g., WSX message id)."""
        return self._external_id

    @external_id.setter
    def external_id(self, value: str | None) -> None:
        self._external_id = value

    @property
    def tytx_mode(self) -> bool:
        """Whether the request uses TYTX serialization."""
        return self._tytx_mode

    @tytx_mode.setter
    def tytx_mode(self, value: bool) -> None:
        self._tytx_mode = value

    @property
    def tytx_transport(self) -> str | None:
        """TYTX transport type ('json', 'msgpack'), or None."""
        return self._tytx_transport

    @tytx_transport.setter
    def tytx_transport(self, value: str | None) -> None:
        self._tytx_transport = value

    @property
    def app_name(self) -> str | None:
        """Name of the app handling this request (set after routing)."""
        return self._app_name

    @app_name.setter
    def app_name(self, value: str | None) -> None:
        self._app_name = value

    @property
    def created_at(self) -> float:
        """``time.time()`` timestamp captured when the request was constructed."""
        return self._created_at

    @property
    def age(self) -> float:
        """Seconds elapsed since the request was constructed."""
        now = time.time()
        return now - self._created_at

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return (
            f"<{cls_name} id={self.id!r} method={self.method} "
            f"path={self.path!r} transport={self.transport}>"
        )
class HttpRequest(BaseRequest):
    """HTTP request adapter wrapping an ASGI ``http`` scope.

    Construction is two-phase: ``__init__`` only sets placeholder defaults,
    and the awaitable :meth:`init` parses the scope and payload.
    """

    __slots__ = (
        "_scope",
        "_body",
        "_headers",
        "_headers_obj",
        "_cookies",
        "_query",
        "_query_obj",
        "_data",
        "_id",
        "_url",
        "_state",
    )

    def __init__(self) -> None:
        super().__init__()
        # Placeholder defaults; the real values are populated by init().
        self._scope: Scope = {}
        self._body: bytes = b""
        self._headers: dict[str, str] = {}
        self._cookies: dict[str, str] = {}
        self._query: dict[str, Any] = {}
        self._data: Any = None
        self._id: str = ""
        # Lazily built caches, filled on first property access.
        self._url: URL | None = None
        self._state: State | None = None
        self._headers_obj: Headers | None = None
        self._query_obj: QueryParams | None = None

    async def init(
        self,
        scope: Scope,
        receive: Receive,
        send: Send | None = None,
        **kwargs: Any,
    ) -> None:
        """Parse the ASGI scope and payload into this request.

        Args:
            scope: ASGI connection scope.
            receive: ASGI receive callable, handed to ``asgi_data``.
            send: Unused here; accepted for interface compatibility.
            **kwargs: Unused here; accepted for interface compatibility.
        """
        from genro_tytx import asgi_data

        self._scope = scope

        # Parse headers first (needed for TYTX detection). Later duplicates
        # of the same header name overwrite earlier ones.
        self._headers = {
            name.decode("latin-1").lower(): value.decode("latin-1")
            for name, value in scope.get("headers", [])
        }

        # TYTX mode is signalled via the X-TYTX-Transport header.
        tytx_transport = self._headers.get("x-tytx-transport")
        if tytx_transport:
            self._tytx_mode = True
            self._tytx_transport = tytx_transport.lower()

        # asgi_data does the parsing (handles both TYTX and normal requests).
        data = await asgi_data(dict(scope), receive)

        # The raw payload is not retained by this method, so ``body`` stays
        # empty; the parsed payload is available through ``data``.
        self._body = b""
        self._headers = data.get("headers", self._headers)
        self._cookies = data.get("cookies", {})
        self._query = data.get("query", {})
        self._data = data.get("body")

        # Reuse the caller-supplied request ID when present; a missing or
        # empty header falls back to a fresh UUID so the ID is never blank.
        self._id = self._headers.get("x-request-id") or str(uuid.uuid4())
        self._external_id = self._headers.get("x-external-id")

        # Copied from the scope (placed there by middleware).
        self._auth_tags = list(scope.get("auth_tags", []))
        self._env_capabilities = list(scope.get("env_capabilities", []))

    @property
    def id(self) -> str:
        return self._id

    @property
    def method(self) -> str:
        return str(self._scope.get("method", "GET")).upper()

    @property
    def path(self) -> str:
        return str(self._scope.get("path", "/"))

    @property
    def headers(self) -> dict[str, str]:
        return self._headers

    @property
    def cookies(self) -> dict[str, str]:
        return self._cookies

    @property
    def query(self) -> dict[str, Any]:
        return self._query

    @property
    def data(self) -> Any:
        return self._data

    @property
    def transport(self) -> str:
        return "http"

    @property
    def scope(self) -> Scope:
        """Raw ASGI scope dict."""
        return self._scope

    @property
    def body(self) -> bytes:
        """Raw body bytes (empty: init() does not retain the raw payload)."""
        return self._body

    @property
    def scheme(self) -> str:
        """URL scheme: http or https."""
        return str(self._scope.get("scheme", "http"))

    @property
    def url(self) -> URL:
        """Full request URL, built once and cached.

        The network location comes from the scope's ``server`` entry when it
        carries a port. Otherwise (no ``server``, or a ``server`` whose port
        is ``None`` as the ASGI spec allows for unix sockets) it comes from
        the ``Host`` header, defaulting to ``localhost``.
        """
        if self._url is None:
            scheme = self.scheme
            server = self._scope.get("server")
            path = self._scope.get("root_path", "") + self.path
            query_string = self._scope.get("query_string", b"")

            if server and server[1] is not None:
                host, port = server
                # IPv6 literals must be bracketed inside a URL authority.
                if ":" in host and not host.startswith("["):
                    host = f"[{host}]"
                is_default_port = (scheme == "http" and port == 80) or (
                    scheme == "https" and port == 443
                )
                netloc = host if is_default_port else f"{host}:{port}"
            else:
                netloc = self._headers.get("host", "localhost")

            url_str = f"{scheme}://{netloc}{path}"
            if query_string:
                url_str += f"?{query_string.decode('latin-1')}"
            self._url = URL(url_str)
        return self._url

    @property
    def headers_obj(self) -> Headers:
        """Request headers as Headers object, built from the raw scope and cached."""
        if self._headers_obj is None:
            self._headers_obj = headers_from_scope(self._scope)
        return self._headers_obj

    @property
    def query_params(self) -> QueryParams:
        """Query string parameters as QueryParams object, built from the raw scope and cached."""
        if self._query_obj is None:
            self._query_obj = query_params_from_scope(self._scope)
        return self._query_obj

    @property
    def client(self) -> Address | None:
        """Client address (host, port) if the scope provides one."""
        client = self._scope.get("client")
        if client:
            return Address(host=client[0], port=client[1])
        return None

    @property
    def state(self) -> State:
        """Request-scoped state container, created on first access."""
        if self._state is None:
            self._state = State()
        return self._state

    @property
    def content_type(self) -> str | None:
        """Content-Type header value, or None when absent."""
        return self._headers.get("content-type")
class MsgRequest(BaseRequest):
    """
    Message-based request adapter (WSX over WebSocket, NATS, etc.).

    Parses WSX:// formatted messages into the BaseRequest interface.
    Transport-agnostic: works with any message-based protocol.
    """

    __slots__ = (
        "_scope",
        "_send",
        "_id",
        "_method",
        "_path",
        "_headers",
        "_cookies",
        "_query",
        "_data",
        "_transport_type",
        "_websocket",
    )

    def __init__(self) -> None:
        super().__init__()
        # Placeholder defaults; the real values are populated by init().
        self._scope: Scope = {}
        self._send: Send | None = None
        self._id: str = ""
        self._method: str = "GET"
        self._path: str = "/"
        self._headers: dict[str, str] = {}
        self._cookies: dict[str, str] = {}
        self._query: dict[str, Any] = {}
        self._data: Any = None
        self._transport_type: str = "websocket"
        self._websocket: "WebSocket | None" = None

    async def init(
        self,
        scope: Scope,
        receive: Receive,
        send: Send | None = None,
        **kwargs: Any,
    ) -> None:
        """Parse a WSX message into this request.

        Args:
            scope: ASGI connection scope.
            receive: Unused here; accepted for interface compatibility.
            send: ASGI send callable, stored on the request.
            **kwargs: ``message`` (required, ``str`` or ``bytes``),
                ``transport_type`` (default ``"websocket"``) and
                ``websocket`` (optional connection object).

        Raises:
            ValueError: If ``message`` is missing, or the parsed message has
                no ``id`` or ``method`` field.
        """
        self._scope = scope
        self._send = send
        self._transport_type = kwargs.get("transport_type", "websocket")
        self._websocket = kwargs.get("websocket")

        message = kwargs.get("message")
        if message is None:
            raise ValueError("MsgRequest requires 'message' kwarg")

        parsed = self._parse_wsx_message(message)

        # Required fields
        if "id" not in parsed:
            raise ValueError("WSX message missing required 'id' field")
        if "method" not in parsed:
            raise ValueError("WSX message missing required 'method' field")

        # The WSX message 'id' is the client's external_id; the internal
        # server id is always freshly generated.
        self._external_id = parsed["id"]
        self._id = str(uuid.uuid4())

        self._method = parsed["method"].upper()
        # ``or`` (rather than a .get default) also covers fields sent
        # explicitly as null, which would otherwise be stored as None.
        self._path = parsed.get("path") or "/"
        raw_headers = parsed.get("headers") or {}
        # Lowercase the keys so lookups such as "content-type" below work
        # regardless of the casing used by the client.
        self._headers = {str(name).lower(): value for name, value in raw_headers.items()}
        self._cookies = parsed.get("cookies") or {}
        self._query = parsed.get("query") or {}
        self._data = parsed.get("data")

        # TYTX mode: explicit message marker, or a content-type mentioning tytx.
        content_type = str(self._headers.get("content-type") or "")
        self._tytx_mode = bool(parsed.get("tytx", False)) or "tytx" in content_type.lower()

        # Copied from the scope (placed there by middleware).
        self._auth_tags = list(scope.get("auth_tags", []))
        self._env_capabilities = list(scope.get("env_capabilities", []))

    def _parse_wsx_message(self, data: str | bytes) -> dict[str, Any]:
        """Parse WSX protocol message into request dict.

        Handles both text (JSON) and binary (msgpack) WSX messages.

        Args:
            data: Raw message data, either str (JSON) or bytes (msgpack).

        Returns:
            Parsed message as a dict.

        Note:
            Format detection:
            - bytes: parsed as msgpack via ``genro_tytx.from_tytx``
            - str starting with "WSX://": protocol prefix stripped
            - str ending with "::JS": parsed via ``genro_tytx.from_tytx``
            - Other str: standard JSON parsing

            When ``genro_tytx`` raises ImportError, bytes are decoded as
            UTF-8 and parsed as text, and the "::JS" marker is stripped
            before falling back to stdlib json.
        """
        if isinstance(data, bytes):
            # Binary data - try msgpack via from_tytx
            try:
                from genro_tytx import from_tytx

                return dict(from_tytx(data, transport="msgpack"))
            except ImportError:
                data = data.decode("utf-8")

        # String data
        if data.startswith("WSX://"):
            data = data[6:]

        # Check for TYTX JSON marker
        if data.endswith("::JS"):
            try:
                from genro_tytx import from_tytx

                return dict(from_tytx(data))
            except ImportError:
                data = data[:-4]  # Strip marker, parse as regular JSON

        return dict(stdlib_json.loads(data))

    @property
    def id(self) -> str:
        return self._id

    @property
    def method(self) -> str:
        return self._method

    @property
    def path(self) -> str:
        return self._path

    @property
    def headers(self) -> dict[str, str]:
        return self._headers

    @property
    def cookies(self) -> dict[str, str]:
        return self._cookies

    @property
    def query(self) -> dict[str, Any]:
        return self._query

    @property
    def data(self) -> Any:
        return self._data

    @property
    def transport(self) -> str:
        return self._transport_type

    @property
    def scope(self) -> Scope:
        """Access to raw ASGI scope."""
        return self._scope

    @property
    def websocket(self) -> "WebSocket | None":
        """Access to underlying WebSocket connection (if available)."""
        return self._websocket

    @property
    def client(self) -> tuple[str, int] | None:
        """Client address as found under the scope's ``client`` key."""
        return self._scope.get("client")
class RequestRegistry:
    """
    Registry for creating and tracking active requests.

    Responsibilities:
    - Creates appropriate request based on scope["type"] using factories dict
    - Calls async init() on the created request
    - Tracks active requests for monitoring and metrics
    - Provides iteration and lookup by request ID

    Example:
        registry = RequestRegistry()
        request = await registry.create(scope, receive, send)
        print(f"Active: {len(registry)}")
        registry.unregister()
    """

    __slots__ = ("_requests", "factories", "_ctx_request")

    def __init__(
        self,
        factories: dict[str, type[BaseRequest]] | None = None,
    ) -> None:
        """Create an empty registry.

        Args:
            factories: Mapping of scope type to request class. When omitted,
                a copy of ``REQUEST_FACTORIES`` is used so later
                ``register_factory`` calls do not alter the module default.
        """
        self._requests: dict[str, BaseRequest] = {}
        self.factories = factories if factories is not None else REQUEST_FACTORIES.copy()
        # NOTE(review): the Python docs recommend creating ContextVars at
        # module top level; this one is per registry instance. Kept as-is to
        # preserve per-registry isolation - confirm this is intended.
        self._ctx_request: ContextVar[BaseRequest | None] = ContextVar(
            'current_request', default=None
        )

    @property
    def current(self) -> BaseRequest | None:
        """Request bound to the current context by this registry, if any."""
        return self._ctx_request.get()

    async def create(
        self,
        scope: Scope,
        receive: Receive,
        send: Send | None = None,
        **kwargs: Any,
    ) -> BaseRequest:
        """Create, initialise and register a request from an ASGI scope.

        The new request is stored under its ``id`` and bound as the current
        request of the calling context.

        Args:
            scope: ASGI scope; its ``"type"`` entry selects the factory.
            receive: ASGI receive callable, forwarded to ``init()``.
            send: ASGI send callable, forwarded to ``init()``.
            **kwargs: Extra keyword arguments forwarded to ``init()``.

        Returns:
            The initialised request.

        Raises:
            ValueError: If no factory is registered for the scope type.
        """
        scope_type = scope.get("type", "")
        factory = self.factories.get(scope_type)
        if factory is None:
            raise ValueError(f"No factory for scope type: {scope_type!r}")

        request = factory()
        await request.init(scope, receive, send, **kwargs)

        self._requests[request.id] = request
        self._ctx_request.set(request)
        return request

    def register_factory(self, scope_type: str, factory: type[BaseRequest]) -> None:
        """Register (or replace) the factory used for ``scope_type``."""
        self.factories[scope_type] = factory

    def unregister(self, request_id: str | None = None) -> BaseRequest | None:
        """Unregister a request.

        Args:
            request_id: ID of the request to remove. When omitted, the
                request bound to the current context is removed. Passing an
                ID allows cleanup from a context other than the one that
                created the request.

        Returns:
            The removed request, or ``None`` when there was nothing to remove.
        """
        current = self._ctx_request.get()

        if request_id is None:
            if current is not None:
                self._requests.pop(current.id, None)
                self._ctx_request.set(None)
            return current

        removed = self._requests.pop(request_id, None)
        # Only clear the context binding when it refers to the same request;
        # an unrelated current request must stay bound.
        if current is not None and current.id == request_id:
            self._ctx_request.set(None)
        return removed

    def get(self, request_id: str) -> BaseRequest | None:
        """Return the active request with this id, or ``None``."""
        return self._requests.get(request_id)

    def count_by_app(self, app_name: str) -> int:
        """Count active requests whose ``app_name`` equals ``app_name``."""
        return sum(1 for req in self._requests.values() if req.app_name == app_name)

    def __len__(self) -> int:
        """Return number of active requests."""
        return len(self._requests)

    def __iter__(self) -> Iterator[BaseRequest]:
        """Iterate over active requests."""
        return iter(self._requests.values())

    def __contains__(self, request_id: str) -> bool:
        """Check if a request id is registered."""
        return request_id in self._requests

    def __repr__(self) -> str:
        return f"RequestRegistry(active={len(self._requests)})"
# Default factories for request creation REQUEST_FACTORIES: dict[str, type[BaseRequest]] = { "http": HttpRequest, "websocket": MsgRequest, }