# Copyright 2025 Softwell S.r.l.
# Licensed under the Apache License, Version 2.0
"""Session store — protocol and in-memory implementation."""
from __future__ import annotations
import secrets
from typing import Any, Protocol, runtime_checkable
from .session import Session
__all__ = ["SessionStore", "MemorySessionStore"]
@runtime_checkable
class SessionStore(Protocol):
    """Structural interface that every session storage backend satisfies.

    The protocol is ``runtime_checkable``, so ``isinstance(obj, SessionStore)``
    succeeds for any object exposing these five methods. Only the presence of
    the methods is checked at runtime; their signatures are not.
    """

    def get(self, session_id: str) -> Session | None:
        """Look up a session by its identifier.

        Args:
            session_id: Token of the session to fetch.

        Returns:
            The matching Session, or None when it is unknown or has expired.
        """
        ...

    def create(self, auth: dict[str, Any] | None = None) -> Session:
        """Build, register and return a new session carrying a unique token.

        Args:
            auth: Optional auth snapshot to attach to the new session.
        """
        ...

    def delete(self, session_id: str) -> None:
        """Drop the session identified by ``session_id`` from the store."""
        ...

    def dump(self) -> dict[str, Any]:
        """Return a serializable snapshot of the stored sessions."""
        ...

    def restore(self, data: dict[str, Any]) -> None:
        """Load sessions from a snapshot previously produced by ``dump``.

        Args:
            data: Mapping returned by an earlier ``dump()`` call.
        """
        ...
class MemorySessionStore:
    """In-memory session store. Default implementation.

    Sessions are kept in a plain dict owned by the instance. Expired sessions
    are evicted lazily when looked up through ``get``. They can also be evicted
    in bulk with ``purge_expired``, which a long-running process should call
    periodically. Otherwise, sessions that are never looked up again stay in
    memory indefinitely. ``dump`` never serializes expired sessions.

    Attributes:
        _sessions: Dict mapping session_id to Session.
        _default_ttl: Default TTL in seconds for new sessions.
    """

    __slots__ = ("_sessions", "_default_ttl")

    def __init__(self, default_ttl: int = 3600) -> None:
        """Initialize store.

        Args:
            default_ttl: Default time-to-live in seconds for new sessions.
        """
        self._sessions: dict[str, Session] = {}
        self._default_ttl = default_ttl

    def get(self, session_id: str) -> Session | None:
        """Retrieve session by id. Returns None if not found or expired.

        An expired session is removed from the store as a side effect. A live
        session has ``touch()`` called on it before being returned (presumably
        refreshing its last-access time; see Session).
        """
        session = self._sessions.get(session_id)
        if session is None:
            return None
        if session.is_expired():
            # Lazy eviction: the id is known but no longer valid.
            del self._sessions[session_id]
            return None
        session.touch()
        return session

    def create(self, auth: dict[str, Any] | None = None) -> Session:
        """Create a new session with unique token.

        Args:
            auth: Auth dict snapshot to store in session.

        Returns:
            New Session instance, already registered in the store and using
            the store's default TTL.
        """
        # 32 random bytes from a CSPRNG, URL-safe base64 encoded.
        session_id = secrets.token_urlsafe(32)
        session = Session(session_id=session_id, auth=auth, ttl=self._default_ttl)
        self._sessions[session_id] = session
        return session

    def delete(self, session_id: str) -> None:
        """Remove session from store. Unknown ids are ignored."""
        self._sessions.pop(session_id, None)

    def purge_expired(self) -> int:
        """Evict every expired session from the store.

        ``get`` evicts expired sessions only for the ids it is asked about.
        This method sweeps the whole store, so sessions that are never looked
        up again do not accumulate forever.

        Returns:
            Number of sessions removed.
        """
        # Collect first: the dict must not change size while it is iterated.
        expired = [sid for sid, session in self._sessions.items() if session.is_expired()]
        for sid in expired:
            del self._sessions[sid]
        return len(expired)

    def dump(self) -> dict[str, Any]:
        """Serialize all live sessions for persistence.

        Expired sessions are skipped. ``get`` treats them as absent and
        ``restore`` would discard them anyway. ``meta`` is shallow-copied;
        ``auth`` is included by reference, not copied.

        Returns:
            Mapping of session_id to ``{"meta": ..., "auth": ...}``, the
            format accepted by ``restore``.
        """
        return {
            session_id: {
                "meta": dict(session.meta),
                "auth": session.auth,
            }
            for session_id, session in self._sessions.items()
            if not session.is_expired()
        }

    def restore(self, data: dict[str, Any]) -> None:
        """Restore sessions from serialized data.

        Sessions already expired are skipped. A restored session replaces any
        existing session with the same id; other stored sessions are kept.

        Args:
            data: Dict from dump() output.

        Raises:
            KeyError: If an entry lacks ``"meta"`` or its meta lacks
                ``"ttl"``, ``"created_at"`` or ``"last_access"``.
        """
        for session_id, session_data in data.items():
            meta = session_data["meta"]
            session = Session(
                session_id=session_id,
                auth=session_data.get("auth"),
                ttl=meta["ttl"],
            )
            # Reinstate the original timestamps so the TTL keeps counting
            # from when the session was really created/last used.
            session.meta["created_at"] = meta["created_at"]
            session.meta["last_access"] = meta["last_access"]
            if not session.is_expired():
                self._sessions[session_id] = session
if __name__ == "__main__":
    # This module only defines the session store API; there is nothing to
    # run when it is executed directly as a script.
    ...