API reference

tesseract_jax.apply_tesseract(tesseract_client, inputs)

Applies the given Tesseract object to the inputs.

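This function is fully traceable and can be used inside JAX transformations such as jax.jit and jax.grad. It automatically dispatches to the appropriate Tesseract endpoint for the requested operation: for example, plain evaluation calls apply, and reverse-mode differentiation calls vector_jacobian_product.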
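The dispatch described above can be sketched with jax.custom_vjp. This is a minimal illustration under stated assumptions, not the tesseract_jax implementation: MockRosenbrockTesseract and make_apply are hypothetical names, the mock's vector_jacobian_product(inputs, cotangent) signature is simplified and differs from the real Tesseract client's, and the mock evaluates a Rosenbrock function locally with jax.numpy (chosen only because it reproduces the values shown in the example) instead of calling a served container. A real remote call is not traceable by JAX and would need to be wrapped, for instance with jax.pure_callback.

```python
import jax
import jax.numpy as jnp

# Match the float64 outputs shown in the doctest example.
jax.config.update("jax_enable_x64", True)


class MockRosenbrockTesseract:
    """Hypothetical local stand-in for a served Tesseract (not a tesseract_core class)."""

    def apply(self, inputs):
        x, y = inputs["x"], inputs["y"]
        return {"result": (1.0 - x) ** 2 + 100.0 * (y - x**2) ** 2}

    def vector_jacobian_product(self, inputs, cotangent):
        # Simplified, hypothetical signature: returns the VJP for every input.
        x, y = inputs["x"], inputs["y"]
        g = cotangent["result"]
        return {
            "x": g * (-2.0 * (1.0 - x) - 400.0 * x * (y - x**2)),
            "y": g * (200.0 * (y - x**2)),
        }


def make_apply(client):
    """Hypothetical helper: wrap a client so JAX routes each operation to an endpoint."""

    @jax.custom_vjp
    def apply(inputs):
        # Plain evaluation (including under jit) goes to the "apply" endpoint.
        return client.apply(inputs)

    def apply_fwd(inputs):
        # Forward pass of reverse-mode AD: evaluate and keep the inputs as residuals.
        return client.apply(inputs), inputs

    def apply_bwd(inputs, cotangent):
        # Backward pass: delegate to the "vector_jacobian_product" endpoint.
        return (client.vector_jacobian_product(inputs, cotangent),)

    apply.defvjp(apply_fwd, apply_bwd)
    return apply


apply = make_apply(MockRosenbrockTesseract())
inputs = {"x": jnp.array(1.0), "y": jnp.array(2.0)}
outputs = jax.jit(apply)(inputs)
grads = jax.grad(lambda inp: apply(inp)["result"].sum())(inputs)
```

Routing the backward pass through a custom VJP rule is what lets jax.grad work without JAX ever looking inside the wrapped computation: it only needs the primal outputs and a function that maps output cotangents to input cotangents.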

Example

>>> import jax
>>> from tesseract_core import Tesseract
>>> from tesseract_jax import apply_tesseract
>>>
>>> # Create a Tesseract object and some inputs
>>> tesseract_client = Tesseract.from_image("univariate")
>>> tesseract_client.serve()
>>> inputs = {"x": jax.numpy.array(1.0), "y": jax.numpy.array(2.0)}
>>>
>>> # Apply the Tesseract object to the inputs
>>> # (this calls tesseract_client.apply under the hood)
>>> apply_tesseract(tesseract_client, inputs)
{'result': Array(100., dtype=float64)}
>>>
>>> # Compute the gradient of the summed output with respect to the inputs
>>> # (this calls tesseract_client.vector_jacobian_product under the hood)
>>> def apply_fn(inputs):
...     res = apply_tesseract(tesseract_client, inputs)
...     return res["result"].sum()
>>> grad_fn = jax.grad(apply_fn)
>>> grad_fn(inputs)
{'x': Array(-400., dtype=float64, weak_type=True), 'y': Array(200., dtype=float64, weak_type=True)}

Parameters:
  • tesseract_client (Tesseract) – The Tesseract object to apply.

  • inputs (Any) – The inputs to pass to the Tesseract object, for example a dictionary of arrays as in the example above.

Return type:

Any

Returns:

The outputs produced by the Tesseract object for the given inputs.
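In the example, both the inputs and the outputs are dictionaries of arrays (JAX pytrees), and jax.grad returns a gradient with the same structure as the inputs. The sketch below checks this structural property with a hypothetical pure-JAX function, local_apply, standing in for the Tesseract; it does not call tesseract_jax and is only meant to show how pytree inputs and outputs flow through JAX transformations.

```python
import jax
import jax.numpy as jnp


def local_apply(inputs):
    """Hypothetical stand-in for a Tesseract's apply endpoint (not part of tesseract_jax)."""
    x, y = inputs["x"], inputs["y"]
    return {"result": (1.0 - x) ** 2 + 100.0 * (y - x**2) ** 2}


inputs = {"x": jnp.array(1.0), "y": jnp.array(2.0)}
outputs = local_apply(inputs)

# The gradient of a scalar loss mirrors the structure of the inputs pytree.
grads = jax.grad(lambda inp: local_apply(inp)["result"].sum())(inputs)

input_structure = jax.tree_util.tree_structure(inputs)
grad_structure = jax.tree_util.tree_structure(grads)
output_leaves = jax.tree_util.tree_leaves(outputs)
```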