API reference
- tesseract_jax.apply_tesseract(tesseract_client, inputs)
Applies the given Tesseract object to the inputs.
This function is fully traceable and can be used inside JAX transformations such as jax.jit and jax.grad. It dispatches to the Tesseract endpoint that matches the requested operation: for example, plain evaluation calls apply, and reverse-mode differentiation calls vector_jacobian_product.
Example
>>> import jax
>>> from tesseract_core import Tesseract
>>> from tesseract_jax import apply_tesseract
>>>
>>> # The float64 outputs below require 64-bit mode in JAX
>>> jax.config.update("jax_enable_x64", True)
>>>
>>> # Create a Tesseract object and some inputs
>>> tesseract_client = Tesseract.from_image("univariate")
>>> tesseract_client.serve()
>>> inputs = {"x": jax.numpy.array(1.0), "y": jax.numpy.array(2.0)}
>>>
>>> # Apply the Tesseract object to the inputs
>>> # (this calls tesseract_client.apply under the hood)
>>> apply_tesseract(tesseract_client, inputs)
{'result': Array(100., dtype=float64)}
>>>
>>> # Compute the gradient of the outputs with respect to the inputs
>>> # (this calls tesseract_client.vector_jacobian_product under the hood)
>>> def apply_fn(inputs):
...     res = apply_tesseract(tesseract_client, inputs)
...     return res["result"].sum()
>>> grad_fn = jax.grad(apply_fn)
>>> grad_fn(inputs)
{'x': Array(-400., dtype=float64, weak_type=True), 'y': Array(200., dtype=float64, weak_type=True)}
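The endpoint dispatch described above can be illustrated without a running Tesseract. The following is a minimal sketch, not the actual tesseract_jax implementation. It assumes a hypothetical MockRosenbrockClient standing in for the served "univariate" Tesseract (assumed here to compute the Rosenbrock function with a=1, b=100, which reproduces the values in the example), and it uses jax.custom_vjp together with jax.pure_callback so that the forward pass calls the client's apply method while the backward pass calls its vector_jacobian_product method. It also assumes the output has the same shape and dtype as inputs["x"]; a real wrapper would obtain output shapes from the Tesseract itself.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Match the float64 outputs shown in the reference example.
jax.config.update("jax_enable_x64", True)


class MockRosenbrockClient:
    """Hypothetical stand-in for a served Tesseract.

    Exposes host-side `apply` and `vector_jacobian_product` methods that take
    dicts of NumPy arrays and compute f(x, y) = (a - x)**2 + b * (y - x**2)**2.
    """

    a, b = 1.0, 100.0

    def apply(self, inputs):
        x, y = inputs["x"], inputs["y"]
        result = (self.a - x) ** 2 + self.b * (y - x**2) ** 2
        return {"result": np.asarray(result, dtype=x.dtype)}

    def vector_jacobian_product(self, inputs, cotangent):
        x, y = inputs["x"], inputs["y"]
        g = cotangent["result"]
        # Partial derivatives of f, scaled by the incoming cotangent.
        dx = -2.0 * (self.a - x) - 4.0 * self.b * x * (y - x**2)
        dy = 2.0 * self.b * (y - x**2)
        return {
            "x": np.asarray(g * dx, dtype=x.dtype),
            "y": np.asarray(g * dy, dtype=y.dtype),
        }


def make_apply_fn(client):
    """Wrap a client so JAX routes evaluation and VJPs to its endpoints."""

    @jax.custom_vjp
    def apply_fn(inputs):
        # Evaluation: call the host-side `apply` endpoint.
        # Assumption: the output matches the shape and dtype of inputs["x"].
        x = inputs["x"]
        out_spec = {"result": jax.ShapeDtypeStruct(x.shape, x.dtype)}
        return jax.pure_callback(client.apply, out_spec, inputs)

    def apply_fwd(inputs):
        # Keep the primal inputs as residuals for the backward pass.
        return apply_fn(inputs), inputs

    def apply_bwd(inputs, cotangent):
        # Reverse-mode differentiation: call `vector_jacobian_product`.
        in_spec = jax.tree_util.tree_map(
            lambda v: jax.ShapeDtypeStruct(v.shape, v.dtype), inputs
        )
        grads = jax.pure_callback(
            client.vector_jacobian_product, in_spec, inputs, cotangent
        )
        # One cotangent per primal argument (here, the single `inputs` dict).
        return (grads,)

    apply_fn.defvjp(apply_fwd, apply_bwd)
    return apply_fn


apply_fn = make_apply_fn(MockRosenbrockClient())
inputs = {"x": jnp.array(1.0), "y": jnp.array(2.0)}
print(jax.jit(apply_fn)(inputs))
print(jax.grad(lambda inp: apply_fn(inp)["result"].sum())(inputs))
```

Because jax.pure_callback has no differentiation rule of its own, the custom_vjp is what tells jax.grad to call the VJP endpoint instead of tracing through the host code. One limitation of this design: functions defined with custom_vjp do not support forward-mode transformations such as jax.jvp, so a complete wrapper needs more machinery than this sketch.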