# Source code for monawhat.state

"""Module implementing the State monad for handling stateful computations.

Examples:
    Basic usage of State monad:

    >>> # Create a state that increments a counter and returns the old value
    >>> def increment() -> State[int, int]:
    ...     return State(lambda s: (s, s + 1))
    ...
    >>> # Chain multiple state operations
    >>> computation = increment().bind(
    ...     lambda x: increment().bind(
    ...         lambda y: State.pure(x + y)
    ...     )
    ... )
    >>> # Run the computation with initial state 0
    >>> result, final_state = computation.run(0)
    >>> print(result)  # Sum of first two states: 0 + 1 = 1
    1
    >>> print(final_state)  # Final counter value: 2
    2

    Using helper methods:
    
    >>> # Get current state
    >>> get_state = State.get()
    >>> value, state = get_state.run(42)
    >>> assert value == state == 42
    
    >>> # Modify state
    >>> add_one = State.modify(lambda x: x + 1)
    >>> _, new_state = add_one.run(10)
    >>> assert new_state == 11
"""

from typing import TypeVar, Generic
from collections.abc import Callable

from monawhat.base import BaseMonad

S = TypeVar("S")  # Type of the state threaded through a computation
A = TypeVar("A")  # Type of the value a State computation produces
B = TypeVar("B")  # Type of the value produced after a map/bind step


class State(Generic[S, A], BaseMonad[A]):
    """State monad for computations that manipulate state.

    The State monad represents a stateful computation that produces a value
    and potentially modifies some state. It's useful for modeling
    computations with mutable state in a purely functional way.

    A State[S, A] instance represents a function from state S to a tuple
    (A, S), where A is the result value and S is the new state. Constructing
    or combining State objects only stores/wraps functions; the wrapped
    function is invoked when ``run`` (or ``eval`` / ``exec``) is called with
    an initial state.
    """

    def __init__(self, run_fn: Callable[[S], tuple[A, S]]) -> None:
        """Initialize a State monad with a state transition function.

        Args:
            run_fn: A function that takes a state and returns a tuple of
                (value, new_state).
        """
        self._run_fn = run_fn

    def run(self, initial_state: S) -> tuple[A, S]:
        """Execute the stateful computation with the given initial state.

        Args:
            initial_state: The starting state.

        Returns:
            A tuple containing the result value and the final state.
        """
        return self._run_fn(initial_state)

    def eval(self, initial_state: S) -> A:
        """Execute the computation and return only the result value.

        Args:
            initial_state: The starting state.

        Returns:
            The result value; the final state is discarded.
        """
        value, _ = self.run(initial_state)
        return value

    def exec(self, initial_state: S) -> S:
        """Execute the computation and return only the final state.

        Args:
            initial_state: The starting state.

        Returns:
            The final state; the result value is discarded.
        """
        _, final_state = self.run(initial_state)
        return final_state

    # NOTE(review): the three ``_*_implementation`` hooks below are presumably
    # what BaseMonad's public ``map`` / ``bind`` / ``pure`` delegate to (the
    # module examples call ``.bind`` and ``State.pure``) -- confirm against
    # monawhat.base.

    def _map_implementation(self, f: Callable[[A], B]) -> "State[S, B]":
        """Apply a function to the result value while preserving the state.

        Args:
            f: The function to apply to the result value.

        Returns:
            A new State that runs this computation, applies ``f`` to its
            value, and passes the resulting state through untouched.
        """

        def new_run(s: S) -> tuple[B, S]:
            value, new_state = self.run(s)
            return f(value), new_state

        return State(new_run)

    def _bind_implementation(
        self, f: Callable[[A], "State[S, B]"]
    ) -> "State[S, B]":
        """Chain this State with a function that returns another State.

        Args:
            f: A function that takes the result of this State and returns a
                new State.

        Returns:
            A new State representing the chained computation.
        """

        def new_run(s: S) -> tuple[B, S]:
            value, intermediate_state = self.run(s)
            # The follow-up computation starts from the state this one left
            # behind, which is how state is threaded through the chain.
            return f(value).run(intermediate_state)

        return State(new_run)

    @classmethod
    def _pure_implementation(cls, value: A) -> "State[S, A]":
        """Create a State that returns the value without modifying the state.

        Args:
            value: The value to return.

        Returns:
            A State that leaves the state unchanged and returns the value.
        """
        return State(lambda s: (value, s))

    @staticmethod
    def get() -> "State[S, S]":
        """Create a State that returns the current state as its value.

        Returns:
            A State whose value is the current state; the state itself is
            left unchanged.
        """
        return State(lambda s: (s, s))

    @staticmethod
    def gets(f: Callable[[S], A]) -> "State[S, A]":
        """Create a State that derives a value from the current state.

        Args:
            f: A function that computes a value from the state.

        Returns:
            A State whose value is ``f(state)``; the state itself is left
            unchanged.
        """
        return State(lambda s: (f(s), s))

    @staticmethod
    def put(new_state: S) -> "State[S, None]":
        """Create a State that replaces the current state.

        Args:
            new_state: The new state to use.

        Returns:
            A State that ignores the incoming state, installs ``new_state``
            and yields ``None`` as its value.
        """
        return State(lambda _: (None, new_state))

    @staticmethod
    def modify(f: Callable[[S], S]) -> "State[S, None]":
        """Create a State that modifies the current state using a function.

        Args:
            f: A function that transforms the state.

        Returns:
            A State that replaces the state with ``f(state)`` and yields
            ``None`` as its value.
        """
        return State(lambda s: (None, f(s)))

    def __repr__(self) -> str:
        """Return a string representation of the State."""
        return f"State({self._run_fn})"