Gradient-based optimization of fluid flows

Note

All examples are expected to run from the examples/<example_name> directory of the Tesseract-JAX repository.

In this example, you will learn how to:

  1. Build a Tesseract that wraps a differentiable simulator from JAX-CFD.

  2. Access its endpoints via Tesseract-JAX’s apply_tesseract() function.

  3. Perform gradient-based optimization of the fluid simulation (via scipy.optimize.minimize), using the Tesseract as a differentiable simulator.

The goal of this application is to find initial conditions for which the final fluid flow is close to this image (the Pasteur Labs logo):

from IPython.display import Image

Image(filename="pl.png", width=200)
# Install additional requirements for this notebook
%pip install -r requirements.txt -q
Note: you may need to restart the kernel to use updated packages.

Step 1: Build + serve JAX-CFD Tesseract

cfd-tesseract is a differentiable Navier-Stokes solver based on JAX-CFD that is wrapped in a Tesseract.

Here is its apply function, as defined in cfd_tesseract/tesseract_api.py:

def cfd_fwd(
    v0: jnp.ndarray,
    density: float,
    viscosity: float,
    inner_steps: int,
    outer_steps: int,
    max_velocity: float,
    cfl_safety_factor: float,
    domain_size_x: float,
    domain_size_y: float,
) -> jax.Array:
    """Compute the final velocity field using the semi-implicit Navier-Stokes equations.

    Args:
        v0: Initial velocity field.
        density: Density of the fluid.
        viscosity: Viscosity of the fluid.
        inner_steps: Number of solver steps for each timestep.
        outer_steps: Number of timesteps.
        max_velocity: Maximum velocity.
        cfl_safety_factor: CFL safety factor.
        domain_size_x: Domain size in x direction.
        domain_size_y: Domain size in y direction.

    Returns:
        Final velocity field.
    """
    vx0 = v0[..., 0]
    vy0 = v0[..., 1]
    bc = cfd.boundaries.HomogeneousBoundaryConditions(
        (
            (cfd.boundaries.BCType.PERIODIC, cfd.boundaries.BCType.PERIODIC),
            (cfd.boundaries.BCType.PERIODIC, cfd.boundaries.BCType.PERIODIC),
        )
    )

    # reconstruct grid from input
    grid = cfd.grids.Grid(
        vx0.shape, domain=((0.0, domain_size_x), (0.0, domain_size_y))
    )

    vx0 = cfd.grids.GridArray(vx0, grid=grid, offset=(1.0, 0.5))
    vy0 = cfd.grids.GridArray(vy0, grid=grid, offset=(0.5, 1.0))

    # reconstruct GridVariable from input
    vx0 = cfd.grids.GridVariable(vx0, bc)
    vy0 = cfd.grids.GridVariable(vy0, bc)
    v0 = (vx0, vy0)

    # Choose a time step.
    dt = cfd.equations.stable_time_step(
        max_velocity, cfl_safety_factor, viscosity, grid
    )

    # Define a step function and use it to compute a trajectory.
    step_fn = cfd.funcutils.repeated(
        cfd.equations.semi_implicit_navier_stokes(
            density=density, viscosity=viscosity, dt=dt, grid=grid
        ),
        steps=inner_steps,
    )
    rollout_fn = cfd.funcutils.trajectory(step_fn, outer_steps)
    _, trajectory = jax.device_get(rollout_fn(v0))
    vxn = trajectory[0].array.data[-1]
    vyn = trajectory[1].array.data[-1]
    return jnp.stack([vxn, vyn], axis=-1)
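
To get a feel for the time step chosen inside cfd_fwd, here is a minimal sketch of the usual CFL-style estimates, using the parameter values from the setup cell in Step 2. The formulas are the standard textbook ones and are an assumption on our part; cfd.equations.stable_time_step may differ in detail.

```python
import math

# Parameter values from the notebook's setup cell.
size = 64
domain_size = 2 * math.pi
max_velocity = 3.0
cfl_safety_factor = 0.5
viscosity = 0.01
inner_steps, outer_steps = 25, 30

dx = domain_size / size  # uniform grid spacing

# Advective (CFL) limit: fluid moving at max_velocity should cross at most
# a fraction cfl_safety_factor of one cell per step.
dt_advection = cfl_safety_factor * dx / max_velocity

# Textbook stability limit for explicit diffusion in 2D (assumed formula,
# shown for comparison only).
dt_diffusion = dx**2 / (4 * viscosity)

total_steps = inner_steps * outer_steps
total_time = total_steps * dt_advection

print(f"dx           = {dx:.5f}")
print(f"dt_advection = {dt_advection:.5f}")
print(f"dt_diffusion = {dt_diffusion:.5f}")
print(f"total steps  = {total_steps}, simulated time = {total_time:.2f}")
```

Under these assumptions, the advective limit (roughly 0.016) is far smaller than the diffusive one, so the 25 × 30 = 750 solver steps cover about 12 time units.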

To build the Tesseract, we use the tesseract command line tool.

%%bash
# Build CFD Tesseract so we can use it below
tesseract build cfd-tesseract/
 [i] Building image ...
 Processing
 [i] Built image sha256:7632fcec515b, ['jax-cfd:latest']
["jax-cfd:latest"]

To interact with the Tesseract, we use the Python SDK from tesseract_core to load the built image and start a server container.

from tesseract_core import Tesseract

cfd_tesseract = Tesseract.from_image("jax-cfd")
cfd_tesseract.serve()

Step 2: Test forward evaluation with Tesseract-JAX

Let’s set up the Tesseract with Tesseract-JAX and test a simple forward evaluation. First, we’ll define an initial guess for the velocity field over a grid. The resulting vx and vy give our horizontal and vertical velocity fields, respectively.

# Import necessary libraries
import jax
import jax.numpy as jnp
import jax_cfd.base as cfd
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image as PILImage
from scipy.optimize import minimize
from tqdm import tqdm

from tesseract_jax import apply_tesseract

# Set up the CFD simulation parameters
seed = 0
size = 64
max_velocity = 3.0
domain_size_x = jnp.pi * 2
domain_size_y = jnp.pi * 2

bc = cfd.boundaries.HomogeneousBoundaryConditions(
    (
        (cfd.boundaries.BCType.PERIODIC, cfd.boundaries.BCType.PERIODIC),
        (cfd.boundaries.BCType.PERIODIC, cfd.boundaries.BCType.PERIODIC),
    )
)

grid = cfd.grids.Grid((size, size), domain=((0, domain_size_x), (0, domain_size_y)))
v0 = cfd.initial_conditions.filtered_velocity_field(
    jax.random.PRNGKey(seed), grid, max_velocity
)
vx, vy = v0

params = {
    "density": 1.0,
    "viscosity": 0.01,
    "inner_steps": 25,
    "outer_steps": 30,
    "max_velocity": max_velocity,
    "cfl_safety_factor": 0.5,
    "domain_size_x": domain_size_x,
    "domain_size_y": domain_size_y,
}

# Define initial velocity field
v0 = np.stack([np.array(vx.array.data), np.array(vy.array.data)], axis=-1)


# Define the Tesseract function
def cfd_tesseract_fn(v0):
    return apply_tesseract(cfd_tesseract, inputs=dict(v0=v0, **params))


# Apply Tesseract to the initial velocity field
outputs = cfd_tesseract_fn(v0)
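
The Tesseract takes a single array v0 whose last axis holds the two velocity components, which cfd_fwd then splits again with v0[..., 0] and v0[..., 1]. The following minimal sketch illustrates that layout. It uses random placeholder arrays (vx_demo, vy_demo are hypothetical stand-ins, not the filtered velocity field defined above).

```python
import numpy as np

size = 64
rng = np.random.default_rng(0)

# Placeholder components; in the notebook these come from vx.array.data
# and vy.array.data.
vx_demo = rng.standard_normal((size, size))
vy_demo = rng.standard_normal((size, size))

# Stack along a new trailing axis: shape (size, size, 2).
v0_demo = np.stack([vx_demo, vy_demo], axis=-1)
print(v0_demo.shape)

# Indexing the last axis recovers each component unchanged.
vx_back = v0_demo[..., 0]
vy_back = v0_demo[..., 1]
print(np.array_equal(vx_back, vx_demo), np.array_equal(vy_back, vy_demo))
```

The same indexing is used on the Tesseract's output to separate the final velocity components.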

Using the results of the forward pass, we can set up a basic approach for visualising our velocity field. We’ll use matplotlib to show the \(x\) and \(y\) components of the velocity as heatmaps.

vxn = outputs["result"][..., 0]
vyn = outputs["result"][..., 1]

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(vxn, cmap="viridis")
ax[0].set_title("vx")
ax[1].imshow(vyn, cmap="viridis")
ax[1].set_title("vy")
[Figure: heatmaps of the final velocity components, with panels titled "vx" and "vy".]

Next, we define a vorticity function for later use. Recall that vorticity is the curl of the flow velocity, i.e. \(\omega = \nabla \times \mathbf{v}\); in 2D this reduces to the scalar \(\omega = \partial v_y / \partial x - \partial v_x / \partial y\).

def vorticity(vxn, vyn):
    vxn = cfd.grids.GridArray(vxn, grid=grid, offset=(1.0, 0.5))
    vyn = cfd.grids.GridArray(vyn, grid=grid, offset=(0.5, 1.0))

    # reconstruct GridVariable from input
    vxn = cfd.grids.GridVariable(vxn, bc)
    vyn = cfd.grids.GridVariable(vyn, bc)

    # differentiate
    _, dvx_dy = cfd.finite_differences.central_difference(vxn)
    dvy_dx, _ = cfd.finite_differences.central_difference(vyn)

    return dvy_dx.data - dvx_dy.data
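
As a sanity check on the finite-difference formula, the sketch below recomputes the 2D vorticity with plain NumPy on a periodic grid and compares it against a field with a known answer (a Taylor-Green vortex). It is an independent illustration, not the JAX-CFD implementation: it assumes a collocated grid and ignores the staggered offsets used above, and central_difference_periodic is a hypothetical helper.

```python
import numpy as np

size = 64
length = 2 * np.pi
dx = length / size

# Collocated sample points (an assumption; JAX-CFD staggers vx and vy).
coords = np.arange(size) * dx
x, y = np.meshgrid(coords, coords, indexing="ij")  # axis 0 is x, axis 1 is y

# Taylor-Green vortex, whose vorticity is 2 sin(x) sin(y).
vx_tg = np.sin(x) * np.cos(y)
vy_tg = -np.cos(x) * np.sin(y)


def central_difference_periodic(f, axis, spacing):
    """Second-order central difference with periodic wrap-around."""
    return (np.roll(f, -1, axis=axis) - np.roll(f, 1, axis=axis)) / (2 * spacing)


# omega = d(vy)/dx - d(vx)/dy
dvy_dx = central_difference_periodic(vy_tg, axis=0, spacing=dx)
dvx_dy = central_difference_periodic(vx_tg, axis=1, spacing=dx)
omega = dvy_dx - dvx_dy

omega_exact = 2 * np.sin(x) * np.sin(y)
max_error = np.max(np.abs(omega - omega_exact))
print(f"max abs error: {max_error:.2e}")
```

The remaining error comes from the second-order truncation of the central difference and shrinks as the grid is refined.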