# Copyright 2025 Softwell S.r.l.
# Licensed under the Apache License, Version 2.0
"""WSGI Gateway middleware for ASGI applications.
Routes HTTP requests between ASGI and WSGI: if the path matches a known
ASGI route (apps or router children), the request passes through to the
Dispatcher. Everything else goes to the WSGI callable (GenroPy site)
mounted at the root.
The server exposes the WSGI callable via server.wsgi_app. When set,
this middleware is active. When None, all requests pass through normally.
"""
from __future__ import annotations
import io
import sys
from typing import TYPE_CHECKING, Any
from genro_toolbox.smartasync import smartasync # type: ignore[import-untyped]
from . import BaseMiddleware
if TYPE_CHECKING:
from ..types import ASGIApp, Receive, Scope, Send
class WsgiGatewayMiddleware(BaseMiddleware):
    """Gateway that sends non-ASGI HTTP requests to the WSGI app at root.

    Routing is inverted: WSGI is the default destination and ASGI is the
    exception. An HTTP request is forwarded to the next ASGI app only when
    the first segment of its path is one of the prefixes returned by
    ``_get_asgi_prefixes``. Every other HTTP request is served by
    ``server.wsgi_app``. Non-HTTP scopes always pass through. So do all
    requests when no WSGI app is configured.
    """

    middleware_name = "wsgi_gateway"
    middleware_order = 50
    middleware_default = True

    __slots__ = ("_asgi_prefixes",)

    def __init__(self, app: ASGIApp, **kwargs: Any) -> None:
        """Args:
        app: Next ASGI app in the chain (Dispatcher).
        **kwargs: Middleware configuration from YAML.
        """
        super().__init__(app, **kwargs)
        # Filled lazily on first routing decision; None means "not built yet".
        self._asgi_prefixes: set[str] | None = None

    def _get_asgi_prefixes(self) -> set[str]:
        """Return the path prefixes served by ASGI, building them on first use.

        A prefix is the name of every entry in ``self.server.apps`` whose
        ``app_protocol`` attribute equals ``"asgi"``. Entries that lack the
        attribute count as ASGI. The result is cached on the instance, so
        later changes to ``server.apps`` are not seen.

        NOTE(review): the module docs also mention router children, but only
        ``server.apps`` is consulted here. Confirm that is intended.
        """
        if self._asgi_prefixes is not None:
            return self._asgi_prefixes
        prefixes: set[str] = set()
        self._asgi_prefixes = prefixes
        server = self.server
        if server is not None:
            prefixes.update(
                name
                for name, app_obj in server.apps.items()
                if getattr(app_obj, "app_protocol", "asgi") == "asgi"
            )
        return prefixes

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Route: ASGI prefixes → Dispatcher, everything else → WSGI."""
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        server = self.server
        wsgi_app = getattr(server, "wsgi_app", None) if server else None
        if wsgi_app is not None:
            # "" for "/" or an empty path, otherwise the text before the
            # first "/" once surrounding slashes are removed.
            head = scope.get("path", "").strip("/").partition("/")[0]
            if head not in self._get_asgi_prefixes():
                await self._handle_wsgi(scope, receive, send, wsgi_app)
                return

        # No WSGI app configured, or the path belongs to an ASGI app.
        await self.app(scope, receive, send)

    async def _handle_wsgi(
        self,
        scope: Scope,
        receive: Receive,
        send: Send,
        wsgi_app: Any,
    ) -> None:
        """Run ``wsgi_app`` for this request and relay its response over ASGI.

        The request body is read completely before the WSGI app runs. The
        app's complete response is then emitted as one
        ``http.response.start`` message followed by one
        ``http.response.body`` message.

        Args:
            scope: ASGI scope dict.
            receive: ASGI receive callable (used to read body).
            send: ASGI send callable (used to send response).
            wsgi_app: PEP 3333 WSGI callable.
        """
        environ = _build_environ(scope, await _read_body(receive))
        run_in_worker = smartasync(_run_wsgi)
        status_code, response_headers, response_body = await run_in_worker(
            wsgi_app, environ
        )

        start_message = {
            "type": "http.response.start",
            "status": status_code,
            "headers": response_headers,
        }
        await send(start_message)

        body_message = {
            "type": "http.response.body",
            "body": response_body,
        }
        await send(body_message)
async def _read_body(receive: Any) -> bytes:
"""Read full request body from ASGI receive.
Args:
receive: ASGI receive callable.
Returns:
Concatenated body bytes.
"""
chunks: list[bytes] = []
while True:
message = await receive()
chunk: bytes = message.get("body", b"")
if chunk:
chunks.append(chunk)
if not message.get("more_body", False):
break
return b"".join(chunks)
def _build_environ(scope: Scope, body: bytes) -> dict[str, Any]:
"""Build PEP 3333 environ dict from ASGI scope and body.
Args:
scope: ASGI scope dict (type "http").
body: Raw request body bytes.
Returns:
WSGI environ dict conforming to PEP 3333.
"""
server = scope.get("server") or ("localhost", 80)
client = scope.get("client") or ("", 0)
environ: dict[str, Any] = {
"REQUEST_METHOD": scope.get("method", "GET"),
"SCRIPT_NAME": "",
"PATH_INFO": scope.get("path", "/"),
"QUERY_STRING": scope.get("query_string", b"").decode("latin-1"),
"SERVER_NAME": server[0],
"SERVER_PORT": str(server[1]),
"SERVER_PROTOCOL": f"HTTP/{scope.get('http_version', '1.1')}",
"wsgi.version": (1, 0),
"wsgi.url_scheme": scope.get("scheme", "http"),
"wsgi.input": io.BytesIO(body),
"wsgi.errors": sys.stderr,
"wsgi.multithread": True,
"wsgi.multiprocess": False,
"wsgi.run_once": False,
}
if client[0]:
environ["REMOTE_ADDR"] = client[0]
for name, value in scope.get("headers", []):
header_name = name.decode("latin-1")
header_value = value.decode("latin-1")
if header_name == "content-type":
environ["CONTENT_TYPE"] = header_value
elif header_name == "content-length":
environ["CONTENT_LENGTH"] = header_value
else:
key = f"HTTP_{header_name.upper().replace('-', '_')}"
environ[key] = header_value
return environ
def _run_wsgi(
wsgi_callable: Any, environ: dict[str, Any]
) -> tuple[int, list[tuple[bytes, bytes]], bytes]:
"""Execute WSGI callable synchronously. Runs in thread via smartasync.
Args:
wsgi_callable: PEP 3333 WSGI application.
environ: WSGI environ dict.
Returns:
Tuple of (status_code, headers_as_bytes_pairs, response_body).
"""
status_holder: list[str] = []
headers_holder: list[list[tuple[str, str]]] = []
def start_response(
status: str,
response_headers: list[tuple[str, str]],
exc_info: Any = None,
) -> None:
status_holder.append(status)
headers_holder.append(response_headers)
result_iter = wsgi_callable(environ, start_response)
try:
body_parts = list(result_iter)
finally:
if hasattr(result_iter, "close"):
result_iter.close()
status_code = int(status_holder[0].split(" ", 1)[0])
asgi_headers: list[tuple[bytes, bytes]] = [
(name.lower().encode("latin-1"), value.encode("latin-1"))
for name, value in headers_holder[0]
]
body = b"".join(body_parts)
return status_code, asgi_headers, body
if __name__ == "__main__":
    # Import-only module: running it directly intentionally does nothing.
    ...