API reference

tesseract_jax.apply_tesseract(tesseract_client, inputs, *, vmap_method=None)

Applies the given Tesseract object to the inputs.

This function is fully traceable and can be used in JAX transformations like jit, grad, etc. It will automatically dispatch to the appropriate Tesseract endpoint based on the requested operation.

Scalar inputs (such as Python floats and ints) and objects implementing the __array__ protocol are automatically converted to JAX arrays where the Tesseract’s input schema expects arrays. Python sequences (lists, tuples) are rejected with a TypeError; convert them explicitly via jnp.array().
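The conversion rule above means list or tuple values must be turned into arrays before the call. A minimal sketch, assuming a flat input dict; the helper name sequences_to_arrays and the sample values are hypothetical and not part of tesseract_jax:

```python
import jax.numpy as jnp

# Hypothetical raw inputs: "x" is a Python list, which apply_tesseract
# is documented to reject with a TypeError.
raw_inputs = {"x": [1.0, 2.0, 3.0], "y": 2.0}


def sequences_to_arrays(inputs):
    """Convert top-level list/tuple values to JAX arrays (illustrative helper).

    Scalars and existing arrays are passed through unchanged, since they
    are documented as being converted automatically.
    """
    return {
        key: jnp.array(value) if isinstance(value, (list, tuple)) else value
        for key, value in inputs.items()
    }


inputs = sequences_to_arrays(raw_inputs)
print(inputs["x"].shape)  # shape of the converted array
```

This sketch only inspects top-level values; nested input structures would need a recursive variant.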

Example

>>> import jax
>>> from tesseract_core import Tesseract
>>> from tesseract_jax import apply_tesseract
>>>
>>> # Create a Tesseract object and some inputs
>>> tesseract_client = Tesseract.from_image("univariate")
>>> tesseract_client.serve()
>>> inputs = {"x": jax.numpy.array(1.0), "y": jax.numpy.array(2.0)}
>>>
>>> # Apply the Tesseract object to the inputs
>>> # (this calls tesseract_client.apply under the hood)
>>> apply_tesseract(tesseract_client, inputs)
{'result': Array(100., dtype=float64)}
>>>
>>> # Scalar values are automatically converted to arrays
>>> apply_tesseract(tesseract_client, {"x": 1.0, "y": 2.0})
{'result': Array(100., dtype=float64)}
>>>
>>> # Compute the gradient of the outputs with respect to the inputs
>>> # (this calls tesseract_client.vector_jacobian_product under the hood)
>>> def apply_fn(x):
...     res = apply_tesseract(tesseract_client, x)
...     return res["result"].sum()
>>> grad_fn = jax.grad(apply_fn)
>>> grad_fn(inputs)
{'x': Array(-400., dtype=float64, weak_type=True), 'y': Array(200., dtype=float64, weak_type=True)}

Parameters:
  • tesseract_client (Tesseract) – The Tesseract object to apply.

  • inputs (Any) – The inputs to pass to the Tesseract object.

  • vmap_method (Optional[Literal['sequential', 'auto_experimental', 'expand_dims', 'broadcast_all']]) –

    Strategy for handling jax.vmap batching. Must be set explicitly whenever the call is used under jax.vmap. The accepted values are:

    None (default)

    No vmap support. Raises NotImplementedError if jax.vmap is applied. All other JAX transforms (jit, grad) work normally.

    "sequential"

    Calls the Tesseract once per batch element via jax.lax.map. Safe for all Tesseracts regardless of schema.

    "auto_experimental"

    Experimental. Inspects the differentiable input schema at trace time. When all batched differentiable inputs use Array[..., dtype] (ellipsis shape), adds a leading (1,) dimension to unbatched args and sends a single batched call. Otherwise, falls back to "sequential". Only differentiable inputs are considered; non-differentiable array inputs are not yet supported.

    "expand_dims"

    Adds a leading (1,) dimension to every unbatched array arg and sends a single batched call. The Tesseract must broadcast (1, ...) against (batch, ...) internally. Use this when the Tesseract accepts a leading batch dimension on all inputs.

    "broadcast_all"

    Broadcasts every unbatched array arg to (batch, ...), so all args share the same leading dimension. Use this when the Tesseract requires all inputs to have identical shapes.

    Python scalars (float, int, bool) are always static and are never batched regardless of the chosen method. Scalar arrays (0-d, e.g. Float64) are treated as regular array args and will be transformed according to the method.

    See Batching strategies for jax.vmap for a detailed guide.

Return type:

Any

Returns:

The outputs returned by the Tesseract object for the given inputs.
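The vmap_method strategies described above differ mainly in what happens to unbatched array arguments. The sketch below reproduces those shape manipulations in plain JAX. It is not the library's implementation: solve is a hypothetical stand-in for a Tesseract call that broadcasts its inputs, and the array values are made up for illustration.

```python
import jax
import jax.numpy as jnp


def solve(x, scale):
    # Hypothetical stand-in for a Tesseract call; broadcasts like NumPy.
    return x * scale


xs = jnp.arange(12.0).reshape(4, 3)  # batched arg, shape (batch=4, 3)
scale = jnp.array([1.0, 2.0, 3.0])   # unbatched arg, shape (3,)

# "sequential": one call per batch element via jax.lax.map.
sequential = jax.lax.map(lambda x: solve(x, scale), xs)

# "expand_dims": the unbatched arg gains a leading (1,) dimension and the
# callee is expected to broadcast (1, 3) against (4, 3) itself.
scale_expanded = jnp.expand_dims(scale, 0)  # shape (1, 3)
expanded = solve(xs, scale_expanded)

# "broadcast_all": the unbatched arg is materialized at the full batch
# shape, so every arg shares the same leading dimension.
scale_full = jnp.broadcast_to(scale, (xs.shape[0],) + scale.shape)  # (4, 3)
broadcast = solve(xs, scale_full)

print(sequential.shape, scale_expanded.shape, scale_full.shape)
```

For a callee that broadcasts correctly, all three routes give the same result; they differ in the number of calls made and in the shapes the callee must accept.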

tesseract_jax.save_intermediates(fn, *, tag='intermediates')

Functional transformation that captures values tagged by sow().

Returns a new function with the same signature as fn, but whose return value is a tuple (original_result, intermediates) where intermediates is a dictionary mapping sow names to sub-dictionaries with keys "primal", "tangent", and/or "cotangent".

Which keys are present depends on the JAX transformations applied to fn before wrapping with save_intermediates:

  • Plain call: only "primal"

  • jax.grad / jax.vjp: "primal" and "cotangent"

  • jax.jvp: "primal" and "tangent"

save_intermediates should be the outermost transformation. It recursively descends into sub-jaxprs (e.g. inside jax.jit boundaries) so sow calls inside JIT-compiled functions are captured correctly.

Parameters:
  • fn (Callable[..., Any]) – The function to wrap.

  • tag (str) – Only capture intermediates whose tag matches this string. Defaults to "intermediates".

Return type:

Callable[..., tuple[Any, dict[str, dict[str, Any]]]]

Returns:

A new callable (*args, **kwargs) -> (result, intermediates).

Raises:

ValueError – If a sow name is used more than once inside fn.
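To make the "primal", "tangent", and "cotangent" keys concrete, the sketch below computes the same three quantities for one intermediate value by hand, using only jax.jvp and jax.grad. It does not call sow() or save_intermediates(), and it assumes the keys carry their usual forward-mode and reverse-mode meanings; the functions inner and outer are hypothetical.

```python
import jax
import jax.numpy as jnp


def inner(x):
    # Produces the intermediate value one might tag with sow().
    return jnp.sin(x)


def outer(h):
    # Consumes the intermediate and produces the final scalar output.
    return h ** 2


x = jnp.array(0.5)

# Primal and tangent of the intermediate: push a unit tangent on x
# through `inner` with forward-mode differentiation.
h_primal, h_tangent = jax.jvp(inner, (x,), (jnp.ones_like(x),))

# Cotangent of the intermediate: derivative of the final output with
# respect to the intermediate, as seen by reverse-mode differentiation.
h_cotangent = jax.grad(outer)(h_primal)

# Chain rule: tangent times cotangent recovers the full gradient.
full_grad = jax.grad(lambda v: outer(inner(v)))(x)
print(h_primal, h_tangent, h_cotangent, full_grad)
```

Here the tangent describes how the intermediate responds to a perturbation of the input, while the cotangent describes how the output responds to a perturbation of the intermediate.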

tesseract_jax.sow(value, name, *, tag='intermediates')

Tag an intermediate value for capture by save_intermediates().

Acts as the identity function: the return value is always equal to value. When the enclosing function is later wrapped with save_intermediates(), the tagged value (and, if a derivative transformation is active, its tangent or cotangent) will appear in the returned intermediates dictionary.

Parameters:
  • value (TypeVar(T)) – Any JAX-compatible pytree (dict, list, array, nested combinations, …).

  • name (str) – A unique string name used to identify this intermediate in the dictionary returned by save_intermediates(). Using the same name twice inside a single function raises ValueError at save_intermediates time.

  • tag (str) – An optional string tag for grouping intermediates. Only intermediates whose tag matches the one passed to save_intermediates() will be captured. Defaults to "intermediates".

Return type:

TypeVar(T)

Returns:

value, unchanged.