# Source code for genro_asgi.middleware.cors

# Copyright 2025 Softwell S.r.l.
# Licensed under the Apache License, Version 2.0

"""CORS (Cross-Origin Resource Sharing) middleware for ASGI applications.

Adds CORS headers to HTTP responses, enabling cross-origin requests from
browsers. Handles preflight OPTIONS requests automatically.

Config:
    allow_origins (list|str): Origins allowed. Default: ["*"]
    allow_methods (list|str): HTTP methods allowed. Default: GET, POST, PUT, DELETE, OPTIONS, PATCH, HEAD
    allow_headers (list|str): Request headers allowed. Default: ["*"]
    allow_credentials (bool): Allow credentials (cookies). Default: False
    expose_headers (list|str): Response headers to expose. Default: []
    max_age (int): Preflight cache time in seconds. Default: 600

Note:
    When allow_credentials is True, browsers reject a wildcard ("*")
    Access-Control-Allow-Origin, so if "*" is configured the request's
    actual Origin is echoed back instead.

Example:
    Enable CORS in config.yaml::

        middleware:
          cors:
            allow_origins: ["https://example.com", "https://app.example.com"]
            allow_credentials: true
            max_age: 3600
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, MutableMapping

from . import BaseMiddleware
from ..utils import split_and_strip

if TYPE_CHECKING:
    from ..types import ASGIApp, Receive, Scope, Send


class CORSMiddleware(BaseMiddleware):
    """CORS middleware for HTTP requests.

    Handles preflight OPTIONS requests and adds CORS headers to responses.
    Non-HTTP requests pass through unchanged.

    Attributes:
        allow_origins: List of allowed origins.
        allow_methods: List of allowed HTTP methods.
        allow_headers: List of allowed request headers.
        allow_credentials: Whether to allow credentials.
        expose_headers: List of headers to expose to browser.
        max_age: Preflight response cache time in seconds.

    Class Attributes:
        middleware_name: "cors" - identifier for config.
        middleware_order: 300 - runs after auth middleware.
        middleware_default: False - disabled by default.
    """

    middleware_name = "cors"
    middleware_order = 300
    middleware_default = False

    __slots__ = (
        "allow_origins",
        "allow_methods",
        "allow_headers",
        "allow_credentials",
        "expose_headers",
        "max_age",
        "_allow_all_origins",
        "_preflight_headers",
    )

    def __init__(
        self,
        app: ASGIApp,
        allow_origins: str | list[str] | None = None,
        allow_methods: str | list[str] | None = None,
        allow_headers: str | list[str] | None = None,
        allow_credentials: bool = False,
        expose_headers: str | list[str] | None = None,
        max_age: int = 600,
        **kwargs: Any,
    ) -> None:
        """Initialize CORS middleware.

        Args:
            app: Next ASGI application in the middleware chain.
            allow_origins: Origins to allow. Accepts list or comma-separated
                string. Use "*" to allow all origins. Defaults to ["*"].
            allow_methods: HTTP methods to allow. Defaults to GET, POST, PUT,
                DELETE, OPTIONS, PATCH, HEAD.
            allow_headers: Request headers to allow. Defaults to ["*"].
            allow_credentials: Allow cookies/auth headers. Defaults to False.
            expose_headers: Response headers to expose to browser.
                Defaults to [].
            max_age: Preflight cache time in seconds. Defaults to 600.
            **kwargs: Additional arguments passed to BaseMiddleware.
        """
        super().__init__(app, **kwargs)
        self.allow_origins = split_and_strip(allow_origins, ["*"])
        self.allow_methods = split_and_strip(
            allow_methods, ["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH", "HEAD"]
        )
        self.allow_headers = split_and_strip(allow_headers, ["*"])
        self.allow_credentials = allow_credentials
        self.expose_headers = split_and_strip(expose_headers)
        self.max_age = max_age
        self._allow_all_origins = "*" in self.allow_origins
        self._preflight_headers = self._build_preflight_headers()

    def _build_preflight_headers(self) -> list[tuple[bytes, bytes]]:
        """Build static headers for preflight OPTIONS response.

        Returns:
            List of ASGI header tuples for preflight response.

        Note:
            Called once during __init__ and cached in _preflight_headers.
            Includes: Access-Control-Allow-Methods, Max-Age, Allow-Headers,
            and Allow-Credentials if enabled.
        """
        headers = [
            (b"access-control-allow-methods", ", ".join(self.allow_methods).encode()),
            (b"access-control-max-age", str(self.max_age).encode()),
        ]
        if self.allow_headers:
            if "*" in self.allow_headers:
                headers.append((b"access-control-allow-headers", b"*"))
            else:
                headers.append(
                    (b"access-control-allow-headers", ", ".join(self.allow_headers).encode())
                )
        if self.allow_credentials:
            headers.append((b"access-control-allow-credentials", b"true"))
        return headers

    @staticmethod
    def _get_request_header(scope: Scope, name: bytes) -> bytes | None:
        """Return the first value of request header *name*, or None if absent.

        Args:
            scope: ASGI HTTP scope dictionary.
            name: Header name as lowercase bytes (ASGI servers are expected to
                deliver lowercased header names).

        Returns:
            The raw header value bytes, or None when the header is missing.
        """
        for key, value in scope.get("headers", []):
            if key == name:
                return value
        return None

    def _get_cors_headers(self, origin: str | None) -> list[tuple[bytes, bytes]]:
        """Get CORS headers for a response based on request origin.

        Args:
            origin: Origin header value from request, or None if not present.

        Returns:
            A fresh list of ASGI header tuples to add to the response. The
            list is empty if the origin is not allowed or not present.

        Note:
            When allow_credentials is True and all origins are allowed, the
            actual origin is echoed instead of "*" (per CORS spec). Whenever
            the concrete origin is echoed, ``Vary: Origin`` is added so caches
            do not serve one origin's response to another.
        """
        if not origin:
            return []

        headers: list[tuple[bytes, bytes]] = []
        if self._allow_all_origins and not self.allow_credentials:
            headers.append((b"access-control-allow-origin", b"*"))
        elif self._allow_all_origins or origin in self.allow_origins:
            # Echo the concrete origin: required with credentials ("*" is
            # rejected by browsers there) and for explicit allow-lists. The
            # response now depends on Origin, so caches must key on it.
            # latin-1 is the exact inverse of the decode done in __call__.
            headers.append((b"access-control-allow-origin", origin.encode("latin-1")))
            headers.append((b"vary", b"Origin"))
        else:
            # Origin not allowed
            return []

        if self.allow_credentials:
            headers.append((b"access-control-allow-credentials", b"true"))
        if self.expose_headers:
            headers.append(
                (b"access-control-expose-headers", ", ".join(self.expose_headers).encode())
            )
        return headers

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Process request with CORS handling.

        For HTTP requests:
        - Preflight (OPTIONS with Origin and Access-Control-Request-Method):
          answered directly with CORS headers (200, or 400 if origin denied).
        - Other requests: wraps send to add CORS headers to the response.

        Args:
            scope: ASGI scope dictionary.
            receive: ASGI receive callable.
            send: ASGI send callable.

        Note:
            Non-HTTP requests pass through without CORS processing.
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        raw_origin = self._get_request_header(scope, b"origin")
        origin = raw_origin.decode("latin-1") if raw_origin is not None else None

        # A CORS preflight is an OPTIONS request carrying both Origin and
        # Access-Control-Request-Method. Any other OPTIONS request is an
        # ordinary request and must reach the application.
        if (
            scope.get("method", "GET") == "OPTIONS"
            and origin
            and self._get_request_header(scope, b"access-control-request-method") is not None
        ):
            await self._handle_preflight(scope, receive, send, origin)
            return

        cors_headers = self._get_cors_headers(origin)
        if not cors_headers:
            # Nothing to add: skip the send wrapper entirely.
            await self.app(scope, receive, send)
            return

        async def send_with_cors(message: MutableMapping[str, Any]) -> None:
            if message["type"] == "http.response.start":
                headers = list(message.get("headers", []))
                headers.extend(cors_headers)
                message = {**message, "headers": headers}
            await send(message)

        await self.app(scope, receive, send_with_cors)

    async def _handle_preflight(
        self, scope: Scope, receive: Receive, send: Send, origin: str
    ) -> None:
        """Handle preflight OPTIONS request.

        Args:
            scope: ASGI scope dictionary; read for Access-Control-Request-Headers.
            receive: ASGI receive callable (unused).
            send: ASGI send callable for response.
            origin: Origin header value from request.

        Note:
            Returns 400 if origin is not allowed. Returns 200 with full CORS
            preflight headers if allowed. With a wildcard allow_headers, the
            requested headers are echoed back: "*" is not a wildcard for
            credentialed requests and never covers Authorization.
        """
        headers = self._get_cors_headers(origin)
        if not headers:
            # Origin not allowed
            await send({"type": "http.response.start", "status": 400, "headers": []})
            await send({"type": "http.response.body", "body": b""})
            return

        requested_headers = self._get_request_header(
            scope, b"access-control-request-headers"
        )
        if requested_headers and "*" in self.allow_headers:
            headers.extend(
                header
                for header in self._preflight_headers
                if header[0] != b"access-control-allow-headers"
            )
            headers.append((b"access-control-allow-headers", requested_headers))
        else:
            headers.extend(self._preflight_headers)

        await send(
            {
                "type": "http.response.start",
                "status": 200,
                "headers": headers,
            }
        )
        await send({"type": "http.response.body", "body": b""})
if __name__ == "__main__":
    # Running this module directly performs no action.
    ...