Get started
Quick start
Note
Before proceeding, make sure you have a working Docker installation and Python 3.10 or newer.
See also
For more detailed installation instructions, please refer to the Tesseract Core documentation.
Install Tesseract-JAX:
$ pip install tesseract-jax
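If something goes wrong later, a quick check of the prerequisites from the Note above can save time. The sketch below is a hypothetical helper (`check_prerequisites` is not part of Tesseract-JAX) that only inspects your environment. It looks for Python 3.10+, a `docker` executable, the `tesseract_jax` and `tesseract_core` modules, and a `tesseract` command for the build step.

```python
import importlib.util
import shutil
import sys


def check_prerequisites() -> dict:
    """Report which quick-start prerequisites appear to be available.

    Hypothetical helper for illustration; not part of Tesseract-JAX.
    """
    return {
        # The quick start asks for Python 3.10 or newer.
        "python>=3.10": sys.version_info >= (3, 10),
        # Docker must be on PATH to build and serve Tesseracts.
        "docker on PATH": shutil.which("docker") is not None,
        # Python packages imported by the quick-start program.
        "tesseract_jax importable": importlib.util.find_spec("tesseract_jax") is not None,
        "tesseract_core importable": importlib.util.find_spec("tesseract_core") is not None,
        # CLI used by `tesseract build`. An unrelated OCR program is also
        # called `tesseract`, so a hit here is only a hint.
        "tesseract CLI on PATH": shutil.which("tesseract") is not None,
    }


if __name__ == "__main__":
    for name, ok in check_prerequisites().items():
        print(f"{'ok' if ok else 'MISSING':8} {name}")
```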
Build an example Tesseract:
$ git clone https://github.com/pasteurlabs/tesseract-jax
$ tesseract build tesseract-jax/examples/simple/vectoradd_jax
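After the build finishes, you can confirm that Docker sees the new image. The helper below is a hypothetical sketch that shells out to the standard `docker image inspect` command, which exits with a non-zero status when an image is missing. It assumes the image is named `vectoradd_jax` (resolved to Docker's default `latest` tag), as the `Tesseract.from_image("vectoradd_jax")` call in the next step suggests.

```python
import subprocess


def docker_image_exists(name: str) -> bool:
    """Return True if Docker can resolve a local image called `name`.

    Hypothetical helper; relies only on `docker image inspect`.
    """
    try:
        result = subprocess.run(
            ["docker", "image", "inspect", name],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=False,
            timeout=30,
        )
    except (OSError, subprocess.TimeoutExpired):
        # docker is not installed, not on PATH, or did not answer in time.
        return False
    # Non-zero exit status means the image was not found (or the daemon
    # could not be reached).
    return result.returncode == 0


if __name__ == "__main__":
    print("vectoradd_jax built:", docker_image_exists("vectoradd_jax"))
```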
Use it as part of a JAX program:
```python
import jax
import jax.numpy as jnp
from tesseract_core import Tesseract
from tesseract_jax import apply_tesseract

# Load the Tesseract and start serving it
t = Tesseract.from_image("vectoradd_jax")
t.serve()

# Run it with JAX
x = jnp.ones((1000,))
y = jnp.ones((1000,))

def vector_sum(x, y):
    res = apply_tesseract(t, {"a": {"v": x}, "b": {"v": y}}, vmap_method="sequential")
    return res["vector_add"]["result"].sum()

vector_sum(x, y)  # success!

# You can also use it with JAX transformations like JIT and grad
vector_sum_jit = jax.jit(vector_sum)
vector_sum_jit(x, y)

vector_sum_grad = jax.grad(vector_sum)
vector_sum_grad(x, y)

# vmap requires an explicit vmap_method: "sequential" is safe but slow,
# while "auto_experimental" or "expand_dims" is more efficient for Tesseracts that support batching.
vector_sum_vmap = jax.vmap(vector_sum)
vector_sum_vmap(x.reshape(10, 100), y.reshape(10, 100))

# Stop the served Tesseract when you are done
t.teardown()
```
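To sanity-check what the transformations above should return, it can help to swap the Tesseract for a pure-JAX function with the same input and output structure. The sketch below does that without Docker. `fake_vectoradd` is a hypothetical stand-in that assumes the Tesseract simply adds `a.v` and `b.v` elementwise; the real `vectoradd_jax` Tesseract may compute something more involved.

```python
import jax
import jax.numpy as jnp


def fake_vectoradd(inputs):
    """Hypothetical pure-JAX stand-in for the vectoradd_jax Tesseract.

    Assumes the output is the elementwise sum of a.v and b.v, nested under
    res["vector_add"]["result"] like in the quick-start program.
    """
    return {"vector_add": {"result": inputs["a"]["v"] + inputs["b"]["v"]}}


def vector_sum(x, y):
    res = fake_vectoradd({"a": {"v": x}, "b": {"v": y}})
    return res["vector_add"]["result"].sum()


x = jnp.ones((1000,))
y = jnp.ones((1000,))

total = jax.jit(vector_sum)(x, y)    # sum of 1000 twos -> 2000.0
grad_x = jax.grad(vector_sum)(x, y)  # d/dx of sum(x + y) is 1 everywhere
# One scalar per row: 10 sums of 100 twos -> 200.0 each
batched = jax.vmap(vector_sum)(x.reshape(10, 100), y.reshape(10, 100))

print(total, grad_x.shape, batched.shape)
```

If that assumption holds, `vector_sum_grad(x, y)` in the Tesseract-backed program should likewise return an array of ones, and `vector_sum_vmap` should return 10 values.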
See also
See Batching strategies for jax.vmap for a guide on selecting the appropriate vmap_method.
Tip
Now you’re ready to explore our examples for more ways to use Tesseract-JAX.