Get started

Quick start

Note

Before proceeding, make sure you have a working Docker installation and Python 3.10 or newer.
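If you want to check both prerequisites from Python, here is a minimal sketch. `check_prerequisites` is a hypothetical convenience helper, not part of Tesseract-JAX. It assumes that a successful `docker info` call means the Docker daemon is reachable.

```python
import shutil
import subprocess
import sys


def check_prerequisites() -> dict[str, bool]:
    """Hypothetical helper: report whether the Quick start prerequisites look satisfied."""
    docker_ok = False
    if shutil.which("docker") is not None:
        try:
            # `docker info` exits non-zero if the daemon is not running or not reachable.
            result = subprocess.run(["docker", "info"], capture_output=True, timeout=30)
            docker_ok = result.returncode == 0
        except (OSError, subprocess.TimeoutExpired):
            docker_ok = False
    return {
        "python_3_10_or_newer": sys.version_info >= (3, 10),
        "docker_available": docker_ok,
    }


if __name__ == "__main__":
    for name, ok in check_prerequisites().items():
        print(f"{name}: {'ok' if ok else 'MISSING'}")
```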

See also

For more detailed installation instructions, please refer to the Tesseract Core documentation.

  1. Install Tesseract-JAX:

    $ pip install tesseract-jax
    
  2. Build an example Tesseract:

    $ git clone https://github.com/pasteurlabs/tesseract-jax
    $ tesseract build tesseract-jax/examples/simple/vectoradd_jax
    
  3. Use it as part of a JAX program:

    import jax
    import jax.numpy as jnp
    from tesseract_core import Tesseract
    from tesseract_jax import apply_tesseract
    
    # Load the Tesseract
    t = Tesseract.from_image("vectoradd_jax")
    t.serve()
    
    # Run it with JAX
    x = jnp.ones((1000,))
    y = jnp.ones((1000,))
    
    def vector_sum(x, y):
        res = apply_tesseract(t, {"a": {"v": x}, "b": {"v": y}}, vmap_method="sequential")
        return res["vector_add"]["result"].sum()
    
    vector_sum(x, y) # success!
    
    # You can also use it with JAX transformations like JIT and grad
    vector_sum_jit = jax.jit(vector_sum)
    vector_sum_jit(x, y)
    
    vector_sum_grad = jax.grad(vector_sum)
    vector_sum_grad(x, y)
    
    # vmap requires an explicit vmap_method: "sequential" is safe but slow,
    # while "auto_experimental" or "expand_dims" are more efficient for Tesseracts that support batching.
    vector_sum_vmap = jax.vmap(vector_sum)
    vector_sum_vmap(x.reshape(10, 100), y.reshape(10, 100))
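To see what results to expect before building the image, the sketch below swaps `apply_tesseract` for a hypothetical pure-JAX stand-in, `fake_apply_tesseract`, which mimics the nested input/output structure used above. It assumes that `vectoradd_jax` adds the two input vectors elementwise with its default settings. The real Tesseract's scaling or normalization options may behave differently.

```python
import jax
import jax.numpy as jnp


def fake_apply_tesseract(inputs):
    # Hypothetical stand-in for apply_tesseract(t, inputs): same nested
    # input/output structure, but computed in-process with jax.numpy.
    return {"vector_add": {"result": inputs["a"]["v"] + inputs["b"]["v"]}}


def vector_sum(x, y):
    res = fake_apply_tesseract({"a": {"v": x}, "b": {"v": y}})
    return res["vector_add"]["result"].sum()


x = jnp.ones((1000,))
y = jnp.ones((1000,))

total = jax.jit(vector_sum)(x, y)    # sum of 1000 elements, each equal to 2.0
grad_x = jax.grad(vector_sum)(x, y)  # d(sum)/dx_i = 1 for every element
batched = jax.vmap(vector_sum)(x.reshape(10, 100), y.reshape(10, 100))  # 10 sums of 100 elements

print(total, grad_x.shape, batched.shape)
```

If that assumption holds, the plain and jitted calls return 2000.0, the gradient is a vector of ones with shape (1000,), and the vmapped call returns ten values of 200.0.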
    

See also

See Batching strategies for jax.vmap for a guide on selecting the appropriate vmap_method.

Tip

Now you’re ready to explore our examples of ways to use Tesseract-JAX.

Sharp edges

  • Additional required endpoints: Tesseract-JAX requires the abstract_eval Tesseract endpoint to be defined when it is used with automatic differentiation or other JAX transformations such as jax.jit and jax.vmap. This is because JAX traces functions under these transformations, which requires the output shapes and dtypes of every operation to be known before anything is executed. Additionally, reverse-mode gradient transformations such as jax.grad require the vector_jacobian_product endpoint to be defined.

Tip

When creating a new Tesseract based on a JAX function, use tesseract init --recipe jax to define all required endpoints automatically, including abstract_eval and vector_jacobian_product.
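Abstract evaluation means computing output shapes and dtypes from input shapes and dtypes alone, without running the computation. The sketch below illustrates the idea in pure JAX. A hypothetical abstract_eval-style function for the vector-add example is checked against `jax.eval_shape`, which performs the same kind of shape-only evaluation for JAX code. The function name `vector_add_abstract_eval` and its dictionary-based signature are our own illustration, not the exact signature of the Tesseract endpoint.

```python
import jax
import jax.numpy as jnp


def vector_add(inputs):
    # Concrete computation: elementwise sum of two vectors.
    return {"result": inputs["a"] + inputs["b"]}


def vector_add_abstract_eval(abstract_inputs):
    # Hypothetical shape-only counterpart: derives the output shape and dtype
    # from the input ShapeDtypeStructs without touching any array data.
    a, b = abstract_inputs["a"], abstract_inputs["b"]
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    out_dtype = jnp.result_type(a.dtype, b.dtype)
    return {"result": jax.ShapeDtypeStruct(a.shape, out_dtype)}


spec = {
    "a": jax.ShapeDtypeStruct((1000,), jnp.float32),
    "b": jax.ShapeDtypeStruct((1000,), jnp.float32),
}

# jax.eval_shape traces vector_add abstractly, the same information jit/grad/vmap need.
traced = jax.eval_shape(vector_add, spec)
declared = vector_add_abstract_eval(spec)
print(traced["result"], declared["result"])
```

For JAX code, tracing supplies this information automatically. A Tesseract is opaque to JAX, so the shape information has to come from an explicit endpoint instead.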

  • Non-differentiable inputs/outputs: Differentiating through inputs or outputs not marked as Differentiable[...] in the Tesseract schema can raise a ValueError or produce NaN tangents. See the Handling Differentiability page for details and workarounds.

  • No JAX operations inside from_tesseract_api endpoints: When using Tesseract.from_tesseract_api(...), the apply, vector_jacobian_product, and jacobian_vector_product functions in your tesseract_api.py execute inside JAX FFI callbacks. Using jax.numpy or any other JAX operation that allocates arrays in these functions can cause deadlocks, because JAX’s runtime is already holding a lock during the callback.

    Use plain NumPy instead:

    # ❌ Bad — will deadlock under jit/grad
    import jax.numpy as jnp
    
    def apply(inputs):
        return OutputSchema(c=jnp.sin(inputs.a))
    
    # ✅ Good — use numpy for in-process Tesseracts
    import numpy as np
    
    def apply(inputs):
        return OutputSchema(c=np.sin(inputs.a))
    

    Note

    This only affects from_tesseract_api (in-process execution). Tesseracts served via Docker (from_image) run in a separate process and are not subject to this restriction.
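For an in-process Tesseract, the gradient endpoints follow the same rule as apply. The sketch below keeps both apply and vector_jacobian_product in plain NumPy for the sin example above. The dataclass schemas are hypothetical stand-ins for the Pydantic models a real tesseract_api.py defines. The vector_jacobian_product signature shown here (inputs, vjp_inputs, vjp_outputs, cotangent_vector) is an assumption modeled on the tesseract_api.py template and should be checked against the Tesseract Core API reference.

```python
from dataclasses import dataclass

import numpy as np


# Hypothetical stand-ins for the schemas defined in a real tesseract_api.py.
@dataclass
class InputSchema:
    a: np.ndarray


@dataclass
class OutputSchema:
    c: np.ndarray


def apply(inputs: InputSchema) -> OutputSchema:
    # Plain NumPy only: safe inside JAX FFI callbacks.
    return OutputSchema(c=np.sin(inputs.a))


def vector_jacobian_product(inputs, vjp_inputs, vjp_outputs, cotangent_vector):
    # For c = sin(a), the Jacobian is diag(cos(a)), so the VJP is the
    # elementwise product of the cotangent with cos(a). Still NumPy only.
    result = {}
    if "a" in vjp_inputs and "c" in vjp_outputs:
        result["a"] = cotangent_vector["c"] * np.cos(inputs.a)
    return result


inputs = InputSchema(a=np.linspace(0.0, 1.0, 5))
out = apply(inputs)
vjp = vector_jacobian_product(inputs, {"a"}, {"c"}, {"c": np.ones(5)})
print(out.c, vjp["a"])
```

A central finite difference of np.sin is a cheap way to sanity-check a hand-written VJP like this one before wiring it into JAX.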