Building the JAX Solver Tesseract for Lorenz-96

This example demonstrates how the JAX solver Tesseract for the Lorenz-96 model is built for use in the data assimilation demo.

Below are the input and output schema definitions for the Lorenz-96 Tesseract.

class InputSchema(BaseModel):
    """Input schema for forecasting of Lorenz 96 system."""

    # NOTE: the class docstring above is kept verbatim because pydantic uses it
    # as the model's schema description.

    # Initial condition: a 1-D Float32 array of any length (one entry per
    # Lorenz 96 variable), marked as differentiable.
    state: Differentiable[Array[(None,), Float32]] = Field(
        description="A state vector for the Lorenz 96 system",
    )

    # Constant forcing term F in the Lorenz 96 equations.
    F: float = Field(
        default=8.0,
        description="Forcing parameter for Lorenz 96",
    )

    # Step size of the time integrator.
    dt: float = Field(
        default=0.05,
        description="Time step for integration",
    )

    # How many integration steps to perform.
    n_steps: int = Field(
        default=1,
        description="Number of integration steps",
    )


class OutputSchema(BaseModel):
    """Output schema for forecasting of Lorenz 96 system."""

    # 2-D Float32 array (declared shape (None, None)), marked as differentiable.
    # NOTE(review): assuming `apply` forwards apply_jit's output unchanged, this
    # is the array built by lorenz96_multi_step. Row k is then the state *before*
    # integration step k + 1, so row 0 is the initial condition and the state
    # after the last step is not included. Confirm against `apply`.
    result: Differentiable[Array[(None, None), Float32]] = Field(
        description="A trajectory of predictions after integration"
    )

Below is the implementation of the `apply` function, which takes an initial condition and returns a trajectory of physical states. The trajectory has one row per integration step and starts with the initial condition itself.

def lorenz96_step(state: jnp.ndarray, F: float, dt: float) -> jnp.ndarray:
    """Advance a Lorenz 96 state by one classical fourth-order Runge-Kutta step.

    Args:
        state: Current state vector; entry i is the Lorenz 96 variable x_i.
        F: Constant forcing term.
        dt: Integration step size.

    Returns:
        The state after a single RK4 step of length ``dt``, with the same
        shape as ``state``.
    """

    def tendency(x: jnp.ndarray) -> jnp.ndarray:
        """Return dx/dt for the periodic Lorenz 96 equations."""
        # dx_i/dt = (x_{i+1} - x_{i-2}) * x_{i-1} - x_i + F, where the indices
        # are taken modulo the number of variables so the ring wraps around.
        n = x.shape[0]
        positions = jnp.arange(n)
        ahead_one = x[(positions + 1) % n]
        behind_one = x[(positions - 1) % n]
        behind_two = x[(positions - 2) % n]
        return (ahead_one - behind_two) * behind_one - x + F

    # The four RK4 stage slopes: two evaluated at the half step, one at the full step.
    slope1 = tendency(state)
    slope2 = tendency(state + dt * slope1 / 2)
    slope3 = tendency(state + dt * slope2 / 2)
    slope4 = tendency(state + dt * slope3)
    # Combine the stages with the classical (1, 2, 2, 1) / 6 weights.
    return state + dt * (slope1 + 2 * slope2 + 2 * slope3 + slope4) / 6


def lorenz96_multi_step(
    state: jnp.ndarray, F: float, dt: float, n_steps: int
) -> jnp.ndarray:
    """Integrate the Lorenz 96 system for ``n_steps`` RK4 steps via ``jax.lax.scan``.

    Args:
        state: Initial state vector.
        F: Constant forcing term.
        dt: Integration step size.
        n_steps: Number of steps; passed as the ``length`` of the scan.

    Returns:
        Array with one row per step (stacked along a new leading axis).
        Row ``k`` is the state *before* step ``k + 1`` is applied. Row 0 is
        therefore the initial ``state``, and the state reached after the
        final step is not part of the output.
    """

    def advance(
        carry: jnp.ndarray, _unused: Any
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        # Carry the updated state forward, but emit the pre-step state.
        updated = lorenz96_step(carry, F, dt)
        return updated, carry

    _final_state, history = jax.lax.scan(advance, state, xs=None, length=n_steps)
    return history


@eqx.filter_jit
def apply_jit(inputs: dict) -> dict:
    """JIT-compiled core of the Lorenz 96 tesseract.

    Args:
        inputs: Mapping with exactly the keys ``state``, ``F``, ``dt`` and
            ``n_steps``. They are forwarded as keyword arguments to
            ``lorenz96_multi_step``.

    Returns:
        ``{"result": trajectory}``, where ``trajectory`` is the array returned
        by ``lorenz96_multi_step``.
    """
    # NOTE(review): eqx.filter_jit traces array leaves and holds other leaves
    # static. The Python-scalar F, dt and n_steps are therefore compile-time
    # constants: n_steps must be static for the scan length, and each new value
    # triggers a recompile. Confirm against the equinox docs.
    return {"result": lorenz96_multi_step(**inputs)}


def apply(inputs: InputSchema) -> OutputSchema:
    """The apply function for the Lorenz 96 tesseract."""