Get started

Quick start

Note

Before proceeding, make sure you have a working Docker installation and Python 3.10 or later.

See also

For more detailed installation instructions, please refer to the Tesseract Core documentation.

  1. Install Tesseract-JAX:

    $ pip install tesseract-jax
    
  2. Build an example Tesseract:

    $ git clone https://github.com/pasteurlabs/tesseract-jax
    $ tesseract build tesseract-jax/examples/simple/vectoradd_jax
    
  3. Use it as part of a JAX program:

    import jax
    import jax.numpy as jnp
    from tesseract_core import Tesseract
    from tesseract_jax import apply_tesseract
    
    # Load the Tesseract
    t = Tesseract.from_image("vectoradd_jax")
    t.serve()
    
    # Run it with JAX
    x = jnp.ones((1000,))
    y = jnp.ones((1000,))
    
    def vector_sum(x, y):
        res = apply_tesseract(t, {"a": {"v": x}, "b": {"v": y}})
        return res["vector_add"]["result"].sum()
    
    vector_sum(x, y) # success!
    
    # You can also use it with JAX transformations like JIT and grad
    vector_sum_jit = jax.jit(vector_sum)
    vector_sum_jit(x, y)
    
    vector_sum_grad = jax.grad(vector_sum)
    vector_sum_grad(x, y)
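
The program above calls `t.serve()` and never stops the Tesseract. A variant that uses the context-manager form shown under "Sharp edges" is sketched below. It assumes that entering the `with` block serves the Tesseract and leaving it tears the Tesseract down. The function name `run_vector_sum` is hypothetical and only used for illustration.

```python
import jax
import jax.numpy as jnp


def run_vector_sum():
    """Evaluate vector_sum and its gradient, cleaning up the Tesseract afterwards."""
    # Imported here so that defining this function does not require
    # tesseract-core or tesseract-jax to be installed
    from tesseract_core import Tesseract
    from tesseract_jax import apply_tesseract

    x = jnp.ones((1000,))
    y = jnp.ones((1000,))

    # Assumption: the context manager serves the Tesseract on entry
    # and tears it down on exit
    with Tesseract.from_image("vectoradd_jax") as t:

        def vector_sum(x, y):
            res = apply_tesseract(t, {"a": {"v": x}, "b": {"v": y}})
            return res["vector_add"]["result"].sum()

        value = jax.jit(vector_sum)(x, y)
        # jax.grad differentiates with respect to the first argument (x) by default
        grads = jax.grad(vector_sum)(x, y)

    return value, grads
```

Calling `run_vector_sum()` requires Docker and the `vectoradd_jax` image built in step 2.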
    

Tip

Now you’re ready to jump into our examples for ways to use Tesseract-JAX.

Sharp edges

  • Arrays vs. array-like objects: Tesseract-JAX is stricter than Tesseract Core: all array inputs to Tesseracts must be JAX or NumPy arrays, not arbitrary array-like objects such as Python floats or lists. You may therefore need to convert your inputs, including scalar values, to JAX arrays before passing them to Tesseract-JAX.

    import jax.numpy as jnp
    from tesseract_core import Tesseract
    from tesseract_jax import apply_tesseract
    
    with Tesseract.from_image("vectoradd_jax") as tess:
        apply_tesseract(tess, {"a": {"v": [1.0]}, "b": {"v": [2.0]}})  # ❌ raises an error
        apply_tesseract(tess, {"a": {"v": jnp.array([1.0])}, "b": {"v": jnp.array([2.0])}})  # ✅ works
    
  • Additional required endpoints: Tesseract-JAX requires the abstract_eval Tesseract endpoint to be defined for all operations. This is because JAX mandates abstract evaluation of all operations before they are executed. Additionally, many gradient transformations like jax.grad require vector_jacobian_product to be defined.
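
One way to satisfy the array requirement from the first item above is to convert every leaf of the input dictionary before calling `apply_tesseract`. The helper below is a minimal sketch. `to_jax_arrays` is a hypothetical name, not part of Tesseract-JAX, and it assumes that every non-dict value in the inputs is meant to be an array.

```python
import jax
import jax.numpy as jnp


def to_jax_arrays(inputs):
    """Recursively convert array-like leaves of a nested dict to JAX arrays.

    Hypothetical helper for illustration; not part of Tesseract-JAX.
    """
    if isinstance(inputs, dict):
        return {key: to_jax_arrays(value) for key, value in inputs.items()}
    # Lists, tuples, and Python scalars are treated as array data
    return jnp.asarray(inputs)


raw_inputs = {"a": {"v": [1.0]}, "b": {"v": 2.0}}
clean_inputs = to_jax_arrays(raw_inputs)
```

The result can then be passed as the second argument of `apply_tesseract`. If a Tesseract also takes non-array inputs, such as strings, those fields would need to be excluded from the conversion.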

Tip

When creating a new Tesseract based on a JAX function, use tesseract init --recipe jax to define all required endpoints automatically, including abstract_eval and vector_jacobian_product.
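
To make the roles of `abstract_eval` and `vector_jacobian_product` more concrete, the sketch below shows the two corresponding concepts in plain JAX, using `jax.eval_shape` and `jax.vjp`. It illustrates the ideas only and is not a description of how Tesseract-JAX implements these endpoints. The function `vector_add` is a hypothetical stand-in for the computation a Tesseract might wrap.

```python
import jax
import jax.numpy as jnp


def vector_add(a, b):
    # Hypothetical stand-in for the computation inside a Tesseract
    return a + b


x = jnp.ones((1000,))
y = jnp.ones((1000,))

# Abstract evaluation: infer the output shape and dtype
# without computing any output values
out_spec = jax.eval_shape(vector_add, x, y)

# Vector-Jacobian product: map a cotangent on the output
# back to cotangents on each input
out, vjp_fn = jax.vjp(vector_add, x, y)
grad_a, grad_b = vjp_fn(jnp.ones_like(out))
```

Transformations such as `jax.jit` rely on the first kind of information, and reverse-mode transformations such as `jax.grad` rely on the second, which is why a Tesseract used under those transformations needs both endpoints.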