Basic example: vector addition

Note

All examples are expected to run from the examples/<example_name> directory of the Tesseract-JAX repository.

Tesseract-JAX is a lightweight extension to Tesseract Core that makes Tesseracts look and feel like regular JAX primitives, and makes them jittable, differentiable, and composable.

In this example, you will learn how to:

  1. Build a Tesseract that performs vector addition.

  2. Access its endpoints via Tesseract-JAX’s apply_tesseract() function.

  3. Compose Tesseracts into more complex functions, blending multiple Tesseract applications with local operations.

  4. Apply jax.jit to the resulting pipeline for JIT compilation, and/or differentiate it automatically (via jax.grad, jax.jvp, jax.vjp, …).

Step 1: Build + serve example Tesseract

In this example, we build and use a Tesseract that performs vector addition. The example Tesseract takes two vectors (each with a scalar scaling factor) and a norm order as input, and returns the scaled sum and difference of the vectors, together with normalized versions of both, as output. Here is the functionality implemented in the Tesseract (see vectoradd_jax/tesseract_api.py):

import jax.numpy as jnp


def apply_jit(inputs: dict) -> dict:
    a_scaled = inputs["a"]["s"] * inputs["a"]["v"]
    b_scaled = inputs["b"]["s"] * inputs["b"]["v"]
    add_result = a_scaled + b_scaled
    min_result = a_scaled - b_scaled

    def safe_norm(x, ord):
        # Compute the norm of a vector, adding a small epsilon to ensure
        # differentiability and avoid division by zero
        return jnp.power(jnp.power(jnp.abs(x), ord).sum() + 1e-8, 1 / ord)

    return {
        "vector_add": {
            "result": add_result,
            "normed_result": add_result / safe_norm(add_result, ord=inputs["norm_ord"]),
        },
        "vector_min": {
            "result": min_result,
            "normed_result": min_result / safe_norm(min_result, ord=inputs["norm_ord"]),
        },
    }
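Written out as formulas, apply_jit computes the following for input vectors \(\mathbf{v}_a, \mathbf{v}_b\), scale factors \(s_a, s_b\) (the "s" entries) and norm order \(p\) (norm_ord):

\[\mathbf{r}_{+} = s_a\mathbf{v}_a + s_b\mathbf{v}_b, \qquad \mathbf{r}_{-} = s_a\mathbf{v}_a - s_b\mathbf{v}_b\]

\[\hat{\mathbf{r}} = \frac{\mathbf{r}}{\left(\sum_i |r_i|^{p} + 10^{-8}\right)^{1/p}} \quad \text{for } \mathbf{r} \in \{\mathbf{r}_{+}, \mathbf{r}_{-}\}\]

Here \(\mathbf{r}_{+}\) and \(\mathbf{r}_{-}\) are returned as vector_add.result and vector_min.result, and \(\hat{\mathbf{r}}\) as the corresponding normed_result entries. Note that vector_min holds the difference of the two scaled vectors, not an element-wise minimum.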

You may build the example Tesseract either via the command line or by running the cell below (skip this step if the image is already built).

%%bash
# Build vectoradd_jax Tesseract so we can use it below
tesseract build vectoradd_jax/
 [i] Building image ...
 Processing
 [i] Built image sha256:7ae85ba85970, ['vectoradd_jax:latest']
["vectoradd_jax:latest"]

To interact with the Tesseract, we use the Python SDK from tesseract_core to load the built image and start a server container.

from tesseract_core import Tesseract

vectoradd = Tesseract.from_image("vectoradd_jax")
vectoradd.serve()

Step 2: Invoke the Tesseract via Tesseract-JAX

Using the vectoradd_jax Tesseract image we built earlier, let’s add two vectors together, representing the following operation:

\[\begin{pmatrix} 1 \\ 2 \\ 3 \end{pmatrix} + 2 \cdot \begin{pmatrix} 4 \\ 5 \\ 6 \end{pmatrix} = \begin{pmatrix} 9 \\ 12 \\ 15 \end{pmatrix}\]

We can perform this calculation with tesseract_jax.apply_tesseract() by passing it the Tesseract object and the input data as a PyTree (nested dictionary) of JAX arrays. Note that a is passed without a scale factor s, so it enters the sum unscaled.

from pprint import pprint

import jax
import jax.numpy as jnp

from tesseract_jax import apply_tesseract

a = {"v": jnp.array([1.0, 2.0, 3.0], dtype="float32")}
b = {
    "v": jnp.array([4.0, 5.0, 6.0], dtype="float32"),
    "s": jnp.array(2.0, dtype="float32"),
}

outputs = apply_tesseract(vectoradd, inputs={"a": a, "b": b})
pprint(outputs)
{'vector_add': {'normed_result': Array([0.42426407, 0.56568545, 0.70710677], dtype=float32),
                'result': Array([ 9., 12., 15.], dtype=float32)},
 'vector_min': {'normed_result': Array([-0.5025707 , -0.5743665 , -0.64616233], dtype=float32),
                'result': Array([-7., -8., -9.], dtype=float32)}}

As expected, outputs['vector_add']['result'] equals \((9, 12, 15)\).
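The normed_result entries are consistent with the Euclidean norm (\(p = 2\)), which suggests that norm_ord defaults to 2 when it is omitted; this is an inference from the printed values rather than something stated above. Working it out for both outputs:

\[\|(9, 12, 15)\|_2 = \sqrt{81 + 144 + 225} = \sqrt{450} \approx 21.2132, \qquad \frac{(9, 12, 15)}{21.2132} \approx (0.42426, 0.56569, 0.70711)\]

\[\|(-7, -8, -9)\|_2 = \sqrt{49 + 64 + 81} = \sqrt{194} \approx 13.9284, \qquad \frac{(-7, -8, -9)}{13.9284} \approx (-0.50257, -0.57437, -0.64616)\]

The \(10^{-8}\) epsilon inside safe_norm is far too small to affect these digits.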

Step 3: Function composition via Tesseracts

Tesseract-JAX enables you to compose chains of Tesseract evaluations, blended with local operations, while retaining all the benefits of JAX.

The function below applies vectoradd twice, i.e., computes \((\mathbf{a} + s\mathbf{b}) + s\mathbf{b}\), where \(s\) is the scale factor stored in b. It then performs local arithmetic on the normalized result, applies vectoradd once more, and finally returns a single element of the result. The resulting function is still a valid JAX function, and is fully jittable and auto-differentiable.

def fancy_operation(a: dict, b: dict) -> jax.Array:
    """Chain three vectoradd calls with local JAX arithmetic in between."""
    # First call: a + s*b
    result = apply_tesseract(vectoradd, inputs={"a": a, "b": b})
    # Second call: feed the sum back in as the new "a", giving (a + s*b) + s*b
    result = apply_tesseract(
        vectoradd, inputs={"a": {"v": result["vector_add"]["result"]}, "b": b}
    )
    # We can mix and match with local JAX operations
    result = 2.0 * result["vector_add"]["normed_result"] + b["v"]
    # Third call, then return the second element as a scalar
    result = apply_tesseract(vectoradd, inputs={"a": {"v": result}, "b": b})
    return result["vector_add"]["result"][1]


fancy_operation(a, b)
Array(16.135319, dtype=float32)

This is compatible with jax.jit():

jitted_op = jax.jit(fancy_operation)
jitted_op(a, b)
Array(16.135319, dtype=float32)
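Because the vectoradd_jax Tesseract is itself written in JAX, its computation can be mirrored locally to sanity-check these numbers without a running Tesseract. The sketch below is illustrative only: local_vectoradd and local_fancy_operation are hypothetical helpers (not part of Tesseract-JAX), and the sketch assumes that a missing s means a scale of 1 and that norm_ord defaults to 2, both inferred from the outputs printed in Step 2. Under these assumptions it reproduces the value above, and its jax.grad agrees with the Tesseract-based gradients computed in this section.

```python
import jax
import jax.numpy as jnp


def local_vectoradd(inputs: dict, norm_ord: float = 2.0) -> dict:
    """Hypothetical pure-JAX stand-in for the vectoradd_jax Tesseract.

    Assumes a missing scale "s" means 1.0 and a default norm order of 2.
    """
    a_scaled = inputs["a"].get("s", 1.0) * inputs["a"]["v"]
    b_scaled = inputs["b"].get("s", 1.0) * inputs["b"]["v"]
    add_result = a_scaled + b_scaled
    min_result = a_scaled - b_scaled

    def safe_norm(x, ord):
        # Same epsilon-regularized p-norm as in apply_jit
        return jnp.power(jnp.power(jnp.abs(x), ord).sum() + 1e-8, 1 / ord)

    return {
        "vector_add": {
            "result": add_result,
            "normed_result": add_result / safe_norm(add_result, norm_ord),
        },
        "vector_min": {
            "result": min_result,
            "normed_result": min_result / safe_norm(min_result, norm_ord),
        },
    }


def local_fancy_operation(a: dict, b: dict) -> jax.Array:
    # Same call sequence as fancy_operation, with local_vectoradd replacing apply_tesseract
    result = local_vectoradd({"a": a, "b": b})
    result = local_vectoradd({"a": {"v": result["vector_add"]["result"]}, "b": b})
    result = 2.0 * result["vector_add"]["normed_result"] + b["v"]
    result = local_vectoradd({"a": {"v": result}, "b": b})
    return result["vector_add"]["result"][1]


a = {"v": jnp.array([1.0, 2.0, 3.0], dtype="float32")}
b = {
    "v": jnp.array([4.0, 5.0, 6.0], dtype="float32"),
    "s": jnp.array(2.0, dtype="float32"),
}

# Value under jit, and reverse-mode gradients with respect to both input PyTrees
value = jax.jit(local_fancy_operation)(a, b)
grad_a, grad_b = jax.grad(local_fancy_operation, argnums=(0, 1))(a, b)
print(value)
print(grad_a, grad_b)
```

Agreement between such a mirror and the Tesseract-based pipeline is a cheap consistency check; it is only possible here because the Tesseract's internals are known and JAX-based.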

Automatic differentiation is dispatched to the underlying Tesseract's jacobian_vector_product and vector_jacobian_product endpoints, and works as expected:

# jax.grad for reverse-mode autodiff (scalar outputs only)
grad_res = jax.grad(fancy_operation, argnums=[0, 1])(a, b)
print("jax.grad result:")
pprint(grad_res)

# jax.jvp for general forward-mode autodiff
_, jvp = jax.jvp(fancy_operation, (a, b), (a, b))
print("\njax.jvp result:")
pprint(jvp)

# jax.vjp for general reverse-mode autodiff
_, vjp_fn = jax.vjp(fancy_operation, a, b)
vjp = vjp_fn(1.0)
print("\njax.vjp result:")
pprint(vjp)
jax.grad result:
({'v': Array([-0.01284981,  0.03497622, -0.02040852], dtype=float32)},
 {'s': Array(5.002062, dtype=float32),
  'v': Array([-0.05139923,  3.139905  , -0.08163408], dtype=float32)})

jax.jvp result:
Array(25.004124, dtype=float32)

jax.vjp result:
({'v': Array([-0.01284981,  0.03497622, -0.02040852], dtype=float32)},
 {'s': Array(5.002062, dtype=float32),
  'v': Array([-0.05139923,  3.139905  , -0.08163408], dtype=float32)})
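These results are mutually consistent. Since fancy_operation returns a scalar, the VJP with cotangent 1.0 reproduces the gradient exactly, and the JVP along the tangent \((\mathbf{a}, \mathbf{b})\) must equal the inner product of the gradient with that tangent. Using the printed gradient:

\[\begin{aligned}
\text{a's } v:&\quad 1(-0.01284981) + 2(0.03497622) + 3(-0.02040852) \approx -0.00412 \\
\text{b's } s:&\quad 2(5.002062) \approx 10.00412 \\
\text{b's } v:&\quad 4(-0.05139923) + 5(3.139905) + 6(-0.08163408) \approx 15.00412 \\
\text{total}:&\quad -0.00412 + 10.00412 + 15.00412 \approx 25.00412
\end{aligned}\]

which matches the jax.jvp result of 25.004124.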

All of the above also works in combination with jax.jit:

# jax.grad for reverse-mode autodiff (scalar output)
grad_res = jax.jit(jax.grad(fancy_operation, argnums=[0, 1]))(a, b)
print("jax.grad result:")
pprint(grad_res)
jax.grad result:
({'v': Array([-0.01284981,  0.03497622, -0.02040852], dtype=float32)},
 {'s': Array(5.002062, dtype=float32),
  'v': Array([-0.05139923,  3.139905  , -0.08163408], dtype=float32)})
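The idea of routing differentiation to separate endpoints can be illustrated with plain JAX. The sketch below is not how Tesseract-JAX is implemented; it uses jax.custom_vjp as a simpler stand-in, and apply_endpoint and vjp_endpoint are hypothetical placeholders for a black box's "apply" and "vector-Jacobian product" routines. JAX differentiates black_box using the supplied backward rule rather than by differentiating its body, which mirrors how jax.grad on a Tesseract is answered by its vector_jacobian_product endpoint.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-ins for two endpoints of an opaque model.
# In a real Tesseract these would be separate calls; here they are plain JAX code.
def apply_endpoint(x):
    return 3.0 * jnp.sin(x)


def vjp_endpoint(x, cotangent):
    # Hand-written VJP of 3*sin(x): cotangent * 3*cos(x)
    return cotangent * 3.0 * jnp.cos(x)


@jax.custom_vjp
def black_box(x):
    return apply_endpoint(x)


def black_box_fwd(x):
    # Forward pass: call the "apply" endpoint and keep x as the residual
    return apply_endpoint(x), x


def black_box_bwd(x, cotangent):
    # Backward pass: delegate to the "VJP" endpoint; one entry per primal input
    return (vjp_endpoint(x, cotangent),)


black_box.defvjp(black_box_fwd, black_box_bwd)


def loss(x):
    # Local JAX operations composed with the black box
    return jnp.sum(black_box(x) ** 2)


x = jnp.array([0.0, 0.5, 1.0])
grad_x = jax.jit(jax.grad(loss))(x)
print(grad_x)  # analytically 9 * sin(2x)
```

jax.custom_vjp covers reverse mode only: applying jax.jvp to black_box raises an error. Supporting forward mode as well, as Tesseract-JAX does through the jacobian_vector_product endpoint, requires additional machinery such as jax.custom_jvp or a custom primitive.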

Step 4: Clean-up and conclusions

Since we kept the Tesseract alive using .serve(), we need to manually stop it using .teardown() to avoid leaking resources.

This is not necessary when the Tesseract is used in a with statement, since it is then cleaned up automatically when the context is exited.

vectoradd.teardown()

And that’s it! You’ve learned how to build up differentiable pipelines with Tesseracts that blend seamlessly with JAX’s APIs and transformations.