Building the JAX Solver Tesseract for Lorenz-96
This example demonstrates how the JAX solver Tesseract for the Lorenz-96 model is built for the data assimilation demo.
Below is the input and output schema definition for the Lorenz Tesseract.
class InputSchema(BaseModel):
    """Input schema for forecasting of Lorenz 96 system."""
    # Field names mirror the parameters of lorenz96_multi_step
    # (state, F, dt, n_steps), which apply_jit receives via ``**inputs``.
    #
    # Initial condition for the integrator. Wrapped in Differentiable;
    # presumably the (None,) shape means a 1-D float32 array of arbitrary
    # length -- TODO confirm against the Array / Differentiable definitions.
    state: Differentiable[Array[(None,), Float32]] = Field(
        description="A state vector for the Lorenz 96 system"
    )
    # Constant forcing term added to every component of the derivative
    # computed in lorenz96_step. Defaults to 8.0.
    F: float = Field(description="Forcing parameter for Lorenz 96", default=8.0)
    # Step size used by each RK4 update in lorenz96_step. Defaults to 0.05.
    dt: float = Field(description="Time step for integration", default=0.05)
    # Passed as ``length`` to jax.lax.scan in lorenz96_multi_step, so it sets
    # the number of states stacked in the returned trajectory. Defaults to 1.
    n_steps: int = Field(description="Number of integration steps", default=1)
class OutputSchema(BaseModel):
    """Output schema for forecasting of Lorenz 96 system."""
    # Trajectory returned under the "result" key by apply_jit: the stacked
    # per-iteration outputs of the scan in lorenz96_multi_step. Presumably
    # the (None, None) shape means a 2-D float32 array with both dimensions
    # unconstrained (steps x state size) -- TODO confirm.
    # NOTE(review): "trajectorie" in the description below is a typo for
    # "trajectory"; it is left untouched here because the string is part of
    # the emitted schema.
    result: Differentiable[Array[(None, None), Float32]] = Field(
        description="A trajectorie of predictions after integration"
    )
Below is the implementation of the apply function, which takes in an initial condition and returns a trajectory of physical states.
def lorenz96_step(state: jnp.ndarray, F: float, dt: float) -> jnp.ndarray:
    """Advance the Lorenz 96 state by one step of size ``dt`` using classical RK4.

    Args:
        state: Current state vector; neighbours are looked up cyclically
            along axis 0.
        F: Constant forcing term added to every component's tendency.
        dt: Integration step size.

    Returns:
        The state after a single fourth-order Runge-Kutta update.
    """

    def tendency(x: jnp.ndarray) -> jnp.ndarray:
        """Evaluate dx_i/dt = (x_{i+1} - x_{i-2}) * x_{i-1} - x_i + F."""
        size = x.shape[0]
        base = jnp.arange(size)
        # The modulo wraps each shifted index back onto the ring, so the
        # first and last components see each other as neighbours.
        ahead_one = x[(base + 1) % size]
        behind_one = x[(base - 1) % size]
        behind_two = x[(base - 2) % size]
        return (ahead_one - behind_two) * behind_one - x + F

    # Four RK4 stage slopes: start point, two midpoint estimates, end point.
    slope_start = tendency(state)
    slope_mid_a = tendency(state + dt * slope_start / 2)
    slope_mid_b = tendency(state + dt * slope_mid_a / 2)
    slope_end = tendency(state + dt * slope_mid_b)

    # Weighted 1-2-2-1 combination of the stage slopes.
    weighted = slope_start + 2 * slope_mid_a + 2 * slope_mid_b + slope_end
    return state + dt * weighted / 6
def lorenz96_multi_step(
    state: jnp.ndarray, F: float, dt: float, n_steps: int
) -> jnp.ndarray:
    """Roll the Lorenz 96 model forward with ``jax.lax.scan`` and collect states.

    Args:
        state: Initial condition that seeds the scan carry.
        F: Forcing parameter forwarded to ``lorenz96_step``.
        dt: Step size forwarded to ``lorenz96_step``.
        n_steps: Number of scan iterations.

    Returns:
        The stacked per-iteration outputs of the scan. Each iteration emits
        the state it started from, so the first entry is the initial
        condition and the state produced by the last step is not included.
    """

    def advance(carry: jnp.ndarray, _unused: Any) -> tuple:
        # scan expects (next_carry, emitted_output); the emitted value is the
        # state *before* this step is applied.
        stepped = lorenz96_step(carry, F, dt)
        return stepped, carry

    scan_output = jax.lax.scan(advance, state, None, length=n_steps)
    # Element 0 is the final carry (discarded); element 1 is the stacked outputs.
    return scan_output[1]
@eqx.filter_jit
def apply_jit(inputs: dict) -> dict:
    """Run the multi-step Lorenz 96 integration under ``eqx.filter_jit``.

    Args:
        inputs: Mapping whose entries are unpacked as keyword arguments to
            ``lorenz96_multi_step`` (``state``, ``F``, ``dt``, ``n_steps``).

    Returns:
        A dict with the single key ``"result"`` holding the trajectory
        returned by ``lorenz96_multi_step``.
    """
    return {"result": lorenz96_multi_step(**inputs)}
def apply(inputs: InputSchema) -> OutputSchema:
"""The apply function for the Lorenz 96 tesseract."""