Basic example: vector addition
Note: All examples are expected to run from the examples/<example_name> directory of the Tesseract-JAX repository.
Tesseract-JAX is a lightweight extension to Tesseract Core that makes Tesseracts look and feel like regular JAX primitives, and makes them jittable, differentiable, and composable.
In this example, you will learn how to:
- Build a Tesseract that performs vector addition.
- Access its endpoints via Tesseract-JAX's apply_tesseract() function.
- Compose Tesseracts into more complex functions, blending multiple Tesseract applications with local operations.
- Apply jax.jit to the resulting pipeline to perform JIT compilation, and/or autodifferentiate the function (via jax.grad, jax.jvp, jax.vjp, …).
Step 1: Build + serve example Tesseract
In this example, we build and use a Tesseract that performs vector addition. The example Tesseract takes two vectors, each with a scalar multiplier, plus a norm order as input, and returns the sum and difference of the scaled vectors together with their normalized versions. Here is the functionality that's implemented in the Tesseract (see vectoradd_jax/tesseract_api.py):
import jax.numpy as jnp


def apply_jit(inputs: dict) -> dict:
    # Scale each input vector by its scalar
    a_scaled = inputs["a"]["s"] * inputs["a"]["v"]
    b_scaled = inputs["b"]["s"] * inputs["b"]["v"]
    # Sum and difference of the scaled vectors
    add_result = a_scaled + b_scaled
    min_result = a_scaled - b_scaled

    def safe_norm(x, ord):
        # Compute the norm of a vector, adding a small epsilon to ensure
        # differentiability and avoid division by zero
        return jnp.power(jnp.power(jnp.abs(x), ord).sum() + 1e-8, 1 / ord)

    return {
        "vector_add": {
            "result": add_result,
            "normed_result": add_result / safe_norm(add_result, ord=inputs["norm_ord"]),
        },
        "vector_min": {
            "result": min_result,
            "normed_result": min_result / safe_norm(min_result, ord=inputs["norm_ord"]),
        },
    }
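To see why safe_norm adds the epsilon, the standalone sketch below compares its gradient at the zero vector with that of the same p-norm expression without the epsilon. It uses only JAX; naive_norm is a hypothetical helper introduced purely for comparison and is not part of the Tesseract.

```python
import jax
import jax.numpy as jnp


def naive_norm(x, ord):
    # Hypothetical comparison helper: the same p-norm, but without the epsilon
    return jnp.power(jnp.power(jnp.abs(x), ord).sum(), 1 / ord)


def safe_norm(x, ord):
    # Same expression as safe_norm in apply_jit, including the 1e-8 epsilon
    return jnp.power(jnp.power(jnp.abs(x), ord).sum() + 1e-8, 1 / ord)


zero = jnp.zeros(3, dtype="float32")

# Gradient of the first normalized component, x[0] / ||x||, evaluated at x = 0
grad_naive = jax.grad(lambda x: x[0] / naive_norm(x, 2))(zero)
grad_safe = jax.grad(lambda x: x[0] / safe_norm(x, 2))(zero)

print(grad_naive)  # non-finite (NaN) entries: 0/0 and inf * 0 appear in the chain rule
print(grad_safe)   # finite: the norm is bounded below by sqrt(1e-8)
```

Without the epsilon, both the forward value and the gradient hit 0/0 at the origin; with it, the denominator never reaches zero, so the normalized outputs stay differentiable everywhere.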
You may build the example Tesseract either via the command line or by running the cell below (you can skip this step if the image is already built).
%%bash
# Build vectoradd_jax Tesseract so we can use it below
tesseract build vectoradd_jax/
[i] Building image ...
[i] Built image sha256:7ae85ba85970, ['vectoradd_jax:latest']
["vectoradd_jax:latest"]
To interact with the Tesseract, we use the Python SDK from tesseract_core to load the built image and start a server container.
from tesseract_core import Tesseract
vectoradd = Tesseract.from_image("vectoradd_jax")
vectoradd.serve()
Step 2: Invoke the Tesseract via Tesseract-JAX
Using the vectoradd_jax Tesseract image we built earlier, let's add two vectors together, representing the following operation:
\[ (1, 2, 3) + 2 \cdot (4, 5, 6) = (9, 12, 15) \]
We can perform this calculation using the function tesseract_jax.apply_tesseract(), by passing the Tesseract object and the input data as a PyTree (nested dictionary) of JAX arrays.
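Because the inputs form an ordinary PyTree, JAX's own tree utilities can show which arrays (leaves) the nested dictionary contains and in what order JAX sees them. The sketch below uses only jax.tree_util on the same input dictionaries as the example; it says nothing about how Tesseract-JAX serializes the inputs internally.

```python
import jax
import jax.numpy as jnp

a = {"v": jnp.array([1.0, 2.0, 3.0], dtype="float32")}
b = {
    "v": jnp.array([4.0, 5.0, 6.0], dtype="float32"),
    "s": jnp.array(2.0, dtype="float32"),
}
inputs = {"a": a, "b": b}

# Flatten the nested dict into (key path, leaf) pairs; dict keys are visited in sorted order
leaves_with_paths, treedef = jax.tree_util.tree_flatten_with_path(inputs)
for path, leaf in leaves_with_paths:
    print(jax.tree_util.keystr(path), leaf.shape, leaf.dtype)

# The tree structure (treedef) rebuilds the nested dict from the flat list of leaves
leaves = [leaf for _, leaf in leaves_with_paths]
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
```

The leaves come out in the order a.v, b.s, b.v. This flattening is what lets JAX transformations such as jax.jit and jax.grad treat a nested dictionary of arrays as a flat list of inputs.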
from pprint import pprint
import jax
import jax.numpy as jnp
from tesseract_jax import apply_tesseract
a = {"v": jnp.array([1.0, 2.0, 3.0], dtype="float32")}
b = {
    "v": jnp.array([4.0, 5.0, 6.0], dtype="float32"),
    "s": jnp.array(2.0, dtype="float32"),
}
outputs = apply_tesseract(vectoradd, inputs={"a": a, "b": b})
pprint(outputs)
{'vector_add': {'normed_result': Array([0.42426407, 0.56568545, 0.70710677], dtype=float32),
'result': Array([ 9., 12., 15.], dtype=float32)},
'vector_min': {'normed_result': Array([-0.5025707 , -0.5743665 , -0.64616233], dtype=float32),
'result': Array([-7., -8., -9.], dtype=float32)}}
As expected, outputs['vector_add'] gives a value of \((9, 12, 15)\).
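The normalized values in the printed output can be checked by hand. This assumes, as the printed numbers imply, that the Tesseract uses a scale of 1 for \(\mathbf{a}\) (no "s" was passed for it) and a norm order of 2 (no "norm_ord" was passed); the \(10^{-8}\) epsilon in safe_norm is negligible here.

\[
\begin{aligned}
\lVert (9, 12, 15) \rVert_2 &= \sqrt{81 + 144 + 225} = \sqrt{450} \approx 21.213, &
\frac{(9, 12, 15)}{\sqrt{450}} &\approx (0.4243, 0.5657, 0.7071), \\
\mathbf{a} - 2\,\mathbf{b} &= (1, 2, 3) - (8, 10, 12) = (-7, -8, -9), &
\frac{(-7, -8, -9)}{\sqrt{49 + 64 + 81}} &\approx (-0.5026, -0.5744, -0.6462).
\end{aligned}
\]

These match the normed_result entries of vector_add and vector_min above to four decimal places.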
Step 3: Function composition via Tesseracts
Tesseract-JAX enables you to compose chains of Tesseract evaluations, blended with local operations, while the result stays jittable and differentiable like any other JAX function.
The function below applies vectoradd twice, i.e. \((\mathbf{a} + \mathbf{b}) + \mathbf{b}\), then performs local arithmetic on the outputs, applies vectoradd once more, and finally returns a single element of the result. The resulting function is still a valid JAX function, and is fully jittable and auto-differentiable.
def fancy_operation(a: dict, b: dict) -> jax.Array:
    """Chain three vectoradd calls with local JAX arithmetic and return a scalar."""
    # First Tesseract call: a + b
    result = apply_tesseract(vectoradd, inputs={"a": a, "b": b})
    # Second call: (a + b) + b, feeding the previous sum back in as the new "a"
    result = apply_tesseract(
        vectoradd, inputs={"a": {"v": result["vector_add"]["result"]}, "b": b}
    )
    # We can mix and match with local JAX operations
    result = 2.0 * result["vector_add"]["normed_result"] + b["v"]
    result = apply_tesseract(vectoradd, inputs={"a": {"v": result}, "b": b})
    # Return a single element so the function is scalar-valued (required by jax.grad)
    return result["vector_add"]["result"][1]
fancy_operation(a, b)
Array(16.135319, dtype=float32)
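To see where 16.135319 comes from, the composition can be written out in closed form. This is a hand derivation, assuming (as the Step 2 output implies) a scale of 1 for \(\mathbf{a}\) and a norm order of 2. Writing \(\mathbf{v}_a, \mathbf{v}_b\) for the input vectors, \(s = 2\) for the scale of \(\mathbf{b}\), and using subscript 2 for the second component (index [1] in the code):

\[
\begin{aligned}
\mathbf{u} &= (\mathbf{v}_a + s\,\mathbf{v}_b) + s\,\mathbf{v}_b = (1, 2, 3) + 4\,(4, 5, 6) = (17, 22, 27),
\qquad \lVert \mathbf{u} \rVert_2 = \sqrt{1502} \approx 38.7556, \\
\mathbf{r} &= 2\,\frac{\mathbf{u}}{\lVert \mathbf{u} \rVert_2} + \mathbf{v}_b, \\
f(\mathbf{a}, \mathbf{b}) &= \left(\mathbf{r} + s\,\mathbf{v}_b\right)_2
= \frac{2\,u_2}{\lVert \mathbf{u} \rVert_2} + (1 + s)\,v_{b,2}
\approx \frac{44}{38.7556} + 15 \approx 16.1353.
\end{aligned}
\]

The same expression explains the gradients further below: for example, \(\partial f / \partial v_{b,2}\) is dominated by the direct term \(1 + s = 3\).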
This is compatible with jax.jit():
jitted_op = jax.jit(fancy_operation)
jitted_op(a, b)
Array(16.135319, dtype=float32)
Autodifferentiation is automatically dispatched to the underlying Tesseract's jacobian_vector_product and vector_jacobian_product endpoints, and works as expected:
# jax.grad for reverse-mode autodiff (scalar outputs only)
grad_res = jax.grad(fancy_operation, argnums=[0, 1])(a, b)
print("jax.grad result:")
pprint(grad_res)
# jax.jvp for general forward-mode autodiff
_, jvp = jax.jvp(fancy_operation, (a, b), (a, b))
print("\njax.jvp result:")
pprint(jvp)
# jax.vjp for general reverse-mode autodiff
_, vjp_fn = jax.vjp(fancy_operation, a, b)
vjp = vjp_fn(1.0)
print("\njax.vjp result:")
pprint(vjp)
jax.grad result:
({'v': Array([-0.01284981, 0.03497622, -0.02040852], dtype=float32)},
{'s': Array(5.002062, dtype=float32),
'v': Array([-0.05139923, 3.139905 , -0.08163408], dtype=float32)})
jax.jvp result:
Array(25.004124, dtype=float32)
jax.vjp result:
({'v': Array([-0.01284981, 0.03497622, -0.02040852], dtype=float32)},
{'s': Array(5.002062, dtype=float32),
'v': Array([-0.05139923, 3.139905 , -0.08163408], dtype=float32)})
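The idea of delegating derivatives to separate endpoints can be pictured with JAX's public extension points. The sketch below is not Tesseract-JAX's implementation, which may be built quite differently; it only illustrates the concept. jax.pure_callback runs opaque host code that JAX cannot trace into, and jax.custom_vjp tells JAX to take gradients from a separate VJP function instead of differentiating that code. apply_endpoint and vjp_endpoint are hypothetical local stand-ins for a Tesseract's apply and vector_jacobian_product endpoints.

```python
import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical stand-ins for remote endpoints: plain NumPy code that JAX
# cannot trace into, so JAX must be told how to differentiate it.
def apply_endpoint(x):
    # "apply": y = 3 * x**2, elementwise
    return (3.0 * x**2).astype(np.float32)


def vjp_endpoint(x, cotangent):
    # "vector_jacobian_product": cotangent * dy/dx = cotangent * 6 * x
    return (6.0 * x * cotangent).astype(np.float32)


@jax.custom_vjp
def remote_fn(x):
    # Run the opaque host function; JAX only needs the output shape and dtype
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(apply_endpoint, out_type, x)


def remote_fn_fwd(x):
    # Forward pass: compute the output and keep x as the residual for the backward pass
    return remote_fn(x), x


def remote_fn_bwd(x, cotangent):
    # Backward pass: ask the VJP "endpoint" instead of differentiating apply_endpoint
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return (jax.pure_callback(vjp_endpoint, out_type, x, cotangent),)


remote_fn.defvjp(remote_fn_fwd, remote_fn_bwd)


def loss(x):
    return remote_fn(x).sum()


x = jnp.array([1.0, 2.0, 3.0], dtype="float32")
print(jax.grad(loss)(x))           # gradient 6 * x, i.e. 6, 12, 18
print(jax.jit(jax.grad(loss))(x))  # same values under jit
```

Because jax.custom_vjp only defines a reverse-mode rule, calling jax.jvp on remote_fn raises an error; a library that also supports jax.jvp, as Tesseract-JAX does, additionally needs a forward-mode rule backed by a JVP endpoint.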
All the above also works when combined with jax.jit:
# jax.grad for reverse-mode autodiff (scalar output)
grad_res = jax.jit(jax.grad(fancy_operation, argnums=[0, 1]))(a, b)
print("jax.grad result:")
pprint(grad_res)
jax.grad result:
({'v': Array([-0.01284981, 0.03497622, -0.02040852], dtype=float32)},
{'s': Array(5.002062, dtype=float32),
'v': Array([-0.05139923, 3.139905 , -0.08163408], dtype=float32)})
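The printed derivatives are consistent with each other: a JVP whose tangents equal the primals, as in the jax.jvp call above, equals the sum over all inputs of gradient times input, and vjp_fn(1.0) returns the same PyTree as jax.grad. The sketch below checks the first identity and reproduces the printed numbers without a running Tesseract by swapping apply_tesseract(vectoradd, ...) for vectoradd_local, a hypothetical pure-JAX copy of the vector_add branch of apply_jit. It assumes a default scale of 1.0 for a and a norm order of 2, as implied by the Step 2 output.

```python
import jax
import jax.numpy as jnp


def vectoradd_local(inputs):
    # Hypothetical local stand-in for the vectoradd Tesseract (vector_add branch only),
    # assuming a default scale of 1.0 when "s" is missing and a norm order of 2
    a_scaled = inputs["a"].get("s", 1.0) * inputs["a"]["v"]
    b_scaled = inputs["b"].get("s", 1.0) * inputs["b"]["v"]
    add_result = a_scaled + b_scaled
    norm = jnp.sqrt(jnp.sum(add_result**2) + 1e-8)  # safe_norm with ord=2
    return {"vector_add": {"result": add_result, "normed_result": add_result / norm}}


def fancy_operation_local(a, b):
    # Same composition as fancy_operation, with vectoradd_local in place of the Tesseract
    result = vectoradd_local({"a": a, "b": b})
    result = vectoradd_local({"a": {"v": result["vector_add"]["result"]}, "b": b})
    result = 2.0 * result["vector_add"]["normed_result"] + b["v"]
    result = vectoradd_local({"a": {"v": result}, "b": b})
    return result["vector_add"]["result"][1]


a = {"v": jnp.array([1.0, 2.0, 3.0], dtype="float32")}
b = {
    "v": jnp.array([4.0, 5.0, 6.0], dtype="float32"),
    "s": jnp.array(2.0, dtype="float32"),
}

value = jax.jit(fancy_operation_local)(a, b)
grads = jax.jit(jax.grad(fancy_operation_local, argnums=(0, 1)))(a, b)
_, jvp_out = jax.jvp(fancy_operation_local, (a, b), (a, b))

# JVP along the primals equals sum(gradient * primal) over all matching leaves
inner = sum(
    jnp.sum(g * p)
    for g, p in zip(jax.tree_util.tree_leaves(grads), jax.tree_util.tree_leaves((a, b)))
)
print(value, jvp_out, inner)
```

Because the stand-in is plain JAX, this also serves as a quick reference for checking a Tesseract's derivative endpoints against values computed locally.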
Step 4: Clean-up and conclusions
Since we kept the Tesseract alive using .serve(), we need to manually stop it using .teardown() to avoid leaking resources.
This is not necessary when using Tesseract in a with statement, as it automatically cleans up when the context is exited.
vectoradd.teardown()
And that’s it! You’ve learned how to build up differentiable pipelines with Tesseracts that blend seamlessly with JAX’s APIs and transformations.