# Source code for genro_asgi.middleware.wsgi_gateway

# Copyright 2025 Softwell S.r.l.
# Licensed under the Apache License, Version 2.0

"""WSGI Gateway middleware for ASGI applications.

Routes HTTP requests between ASGI and WSGI: if the path matches a known
ASGI route (apps or router children), the request passes through to the
Dispatcher. Everything else goes to the WSGI callable (GenroPy site)
mounted at the root.

The server exposes the WSGI callable via server.wsgi_app. When set,
this middleware is active. When None, all requests pass through normally.
"""

from __future__ import annotations

import io
import sys
from typing import TYPE_CHECKING, Any

from genro_toolbox.smartasync import smartasync  # type: ignore[import-untyped]

from . import BaseMiddleware

if TYPE_CHECKING:
    from ..types import ASGIApp, Receive, Scope, Send


class WsgiGatewayMiddleware(BaseMiddleware):
    """Send HTTP requests to the root WSGI app unless they target an ASGI app.

    WSGI is the default destination and ASGI the exception: a request is
    forwarded down the ASGI chain only when the first segment of its path
    equals the registry name of one of the server's ASGI apps. Every other
    HTTP request is executed by ``server.wsgi_app``. Non-HTTP scopes, and all
    scopes when the server exposes no ``wsgi_app``, are forwarded unchanged.
    """

    middleware_name = "wsgi_gateway"
    middleware_order = 50
    middleware_default = True

    __slots__ = ("_asgi_prefixes",)

    def __init__(self, app: ASGIApp, **kwargs: Any) -> None:
        """Initialise the middleware; ASGI prefixes are resolved lazily.

        Args:
            app: Next ASGI app in the chain (Dispatcher).
            **kwargs: Middleware configuration from YAML.
        """
        super().__init__(app, **kwargs)
        # Filled in by _get_asgi_prefixes() on its first call.
        self._asgi_prefixes: set[str] | None = None

    def _get_asgi_prefixes(self) -> set[str]:
        """Return the cached set of first path segments routed to ASGI.

        The set is built once from ``server.apps``. Every app whose
        ``app_protocol`` is ``"asgi"`` contributes its registry name; a
        missing attribute counts as ``"asgi"``. When no server is attached,
        the (empty) set is cached as-is.
        """
        if self._asgi_prefixes is not None:
            return self._asgi_prefixes

        prefixes: set[str] = set()
        self._asgi_prefixes = prefixes
        server = self.server
        if server is None:
            return prefixes

        for name, app_obj in server.apps.items():
            if getattr(app_obj, "app_protocol", "asgi") == "asgi":
                prefixes.add(name)
        return prefixes

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Route one scope: ASGI prefix → Dispatcher, anything else → WSGI."""
        if scope["type"] == "http":
            server = self.server
            wsgi_app = getattr(server, "wsgi_app", None) if server else None
            if wsgi_app is not None:
                # Surrounding slashes are ignored, so "", "/" and "//" all
                # yield an empty first segment.
                first_segment = scope.get("path", "").strip("/").partition("/")[0]
                if first_segment not in self._get_asgi_prefixes():
                    await self._handle_wsgi(scope, receive, send, wsgi_app)
                    return

        await self.app(scope, receive, send)

    async def _handle_wsgi(
        self,
        scope: Scope,
        receive: Receive,
        send: Send,
        wsgi_app: Any,
    ) -> None:
        """Run ``wsgi_app`` for one request and relay its response over ASGI.

        The request body is read completely before the WSGI app is invoked.
        The collected response is then emitted as one ``http.response.start``
        message followed by one ``http.response.body`` message.

        Args:
            scope: ASGI scope dict.
            receive: ASGI receive callable (used to read body).
            send: ASGI send callable (used to send response).
            wsgi_app: PEP 3333 WSGI callable.
        """
        environ = _build_environ(scope, await _read_body(receive))
        # smartasync adapts the synchronous _run_wsgi so it can be awaited.
        run_wsgi = smartasync(_run_wsgi)
        status, headers, payload = await run_wsgi(wsgi_app, environ)

        start_message = {
            "type": "http.response.start",
            "status": status,
            "headers": headers,
        }
        body_message = {"type": "http.response.body", "body": payload}
        await send(start_message)
        await send(body_message)


async def _read_body(receive: Any) -> bytes:
    """Read full request body from ASGI receive.

    Args:
        receive: ASGI receive callable.

    Returns:
        Concatenated body bytes.
    """
    chunks: list[bytes] = []
    while True:
        message = await receive()
        chunk: bytes = message.get("body", b"")
        if chunk:
            chunks.append(chunk)
        if not message.get("more_body", False):
            break
    return b"".join(chunks)


def _build_environ(scope: Scope, body: bytes) -> dict[str, Any]:
    """Build PEP 3333 environ dict from ASGI scope and body.

    Args:
        scope: ASGI scope dict (type "http").
        body: Raw request body bytes.

    Returns:
        WSGI environ dict conforming to PEP 3333.
    """
    server = scope.get("server") or ("localhost", 80)
    client = scope.get("client") or ("", 0)

    environ: dict[str, Any] = {
        "REQUEST_METHOD": scope.get("method", "GET"),
        "SCRIPT_NAME": "",
        "PATH_INFO": scope.get("path", "/"),
        "QUERY_STRING": scope.get("query_string", b"").decode("latin-1"),
        "SERVER_NAME": server[0],
        "SERVER_PORT": str(server[1]),
        "SERVER_PROTOCOL": f"HTTP/{scope.get('http_version', '1.1')}",
        "wsgi.version": (1, 0),
        "wsgi.url_scheme": scope.get("scheme", "http"),
        "wsgi.input": io.BytesIO(body),
        "wsgi.errors": sys.stderr,
        "wsgi.multithread": True,
        "wsgi.multiprocess": False,
        "wsgi.run_once": False,
    }

    if client[0]:
        environ["REMOTE_ADDR"] = client[0]

    for name, value in scope.get("headers", []):
        header_name = name.decode("latin-1")
        header_value = value.decode("latin-1")

        if header_name == "content-type":
            environ["CONTENT_TYPE"] = header_value
        elif header_name == "content-length":
            environ["CONTENT_LENGTH"] = header_value
        else:
            key = f"HTTP_{header_name.upper().replace('-', '_')}"
            environ[key] = header_value

    return environ


def _run_wsgi(
    wsgi_callable: Any, environ: dict[str, Any]
) -> tuple[int, list[tuple[bytes, bytes]], bytes]:
    """Execute WSGI callable synchronously. Runs in thread via smartasync.

    Args:
        wsgi_callable: PEP 3333 WSGI application.
        environ: WSGI environ dict.

    Returns:
        Tuple of (status_code, headers_as_bytes_pairs, response_body).
    """
    status_holder: list[str] = []
    headers_holder: list[list[tuple[str, str]]] = []

    def start_response(
        status: str,
        response_headers: list[tuple[str, str]],
        exc_info: Any = None,
    ) -> None:
        status_holder.append(status)
        headers_holder.append(response_headers)

    result_iter = wsgi_callable(environ, start_response)

    try:
        body_parts = list(result_iter)
    finally:
        if hasattr(result_iter, "close"):
            result_iter.close()

    status_code = int(status_holder[0].split(" ", 1)[0])

    asgi_headers: list[tuple[bytes, bytes]] = [
        (name.lower().encode("latin-1"), value.encode("latin-1"))
        for name, value in headers_holder[0]
    ]

    body = b"".join(body_parts)

    return status_code, asgi_headers, body


# Running this module directly does nothing: the guard body is a bare `pass`.
if __name__ == "__main__":
    pass