API reference
- tesseract_jax.apply_tesseract(tesseract_client, inputs, *, vmap_method=None)
Applies the given Tesseract object to the inputs.
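This function is fully traceable and can be used inside JAX transformations such as jax.jit and jax.grad. It automatically dispatches to the appropriate Tesseract endpoint for the requested operation (for example, apply for a plain call and vector_jacobian_product for reverse-mode differentiation).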
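Because the call is traceable, it composes with jax.jit and jax.grad like any other JAX function. The sketch below shows that composition with a hypothetical pure-JAX stand-in, fake_apply, in place of a served Tesseract. It assumes (this is not stated in the reference) that the "univariate" example computes the Rosenbrock function with a=1, b=100; that assumption reproduces the numbers shown in the example.

```python
import jax
import jax.numpy as jnp


# Hypothetical pure-JAX stand-in for a served Tesseract. ASSUMPTION: the
# "univariate" example computes the Rosenbrock function with a=1, b=100,
# which matches the values shown in the apply_tesseract example.
def fake_apply(inputs):
    x, y = inputs["x"], inputs["y"]
    return {"result": (1.0 - x) ** 2 + 100.0 * (y - x**2) ** 2}


def apply_fn(inputs):
    # Reduce one output to a scalar so that jax.grad is applicable.
    return fake_apply(inputs)["result"].sum()


inputs = {"x": jnp.array(1.0), "y": jnp.array(2.0)}

value = jax.jit(apply_fn)(inputs)            # forward pass under jit
grads = jax.jit(jax.grad(apply_fn))(inputs)  # jit and grad compose freely
# value is 100.0; grads is {"x": -400.0, "y": 200.0}
```

In real code, fake_apply would be replaced by apply_tesseract(tesseract_client, inputs); under jax.grad that call is served by the Tesseract's vector_jacobian_product endpoint.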
Scalar inputs (such as Python floats and ints) and objects implementing the __array__ protocol are automatically converted to JAX arrays wherever the Tesseract’s input schema expects arrays. Python sequences (lists, tuples) are rejected with a TypeError; convert them explicitly via jnp.array().
Example
>>> import jax
>>> from tesseract_core import Tesseract
>>> from tesseract_jax import apply_tesseract
>>>
>>> # Create a Tesseract object and some inputs
>>> tesseract_client = Tesseract.from_image("univariate")
>>> tesseract_client.serve()
>>> inputs = {"x": jax.numpy.array(1.0), "y": jax.numpy.array(2.0)}
>>>
>>> # Apply the Tesseract object to the inputs
>>> # (this calls tesseract_client.apply under the hood)
>>> apply_tesseract(tesseract_client, inputs)
{'result': Array(100., dtype=float64)}
>>>
>>> # Scalar values are automatically converted to arrays
>>> apply_tesseract(tesseract_client, {"x": 1.0, "y": 2.0})
{'result': Array(100., dtype=float64)}
>>>
>>> # Compute the gradient of the outputs with respect to the inputs
>>> # (this calls tesseract_client.vector_jacobian_product under the hood)
>>> def apply_fn(x):
...     res = apply_tesseract(tesseract_client, x)
...     return res["result"].sum()
>>> grad_fn = jax.grad(apply_fn)
>>> grad_fn(inputs)
{'x': Array(-400., dtype=float64, weak_type=True), 'y': Array(200., dtype=float64, weak_type=True)}
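The input-conversion rule described above (scalars and __array__ objects converted, lists and tuples rejected) can be sketched with a hypothetical helper, to_array_input. This is an illustration of the documented behaviour only, not the tesseract_jax implementation; in particular it does not consult the Tesseract's input schema.

```python
import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical helper sketching the documented conversion rules. It is NOT the
# tesseract_jax implementation and, unlike the real function, it does not
# consult the Tesseract's input schema.
def to_array_input(value):
    # Python sequences are rejected instead of being converted implicitly.
    if isinstance(value, (list, tuple)):
        raise TypeError(
            f"{type(value).__name__} inputs are not converted; use jnp.array()"
        )
    # Python scalars and objects implementing __array__ become JAX arrays.
    if isinstance(value, (int, float)) or hasattr(value, "__array__"):
        return jnp.asarray(value)
    return value


scalar = to_array_input(1.0)                  # Python float -> 0-d array
from_numpy = to_array_input(np.ones((2, 3)))  # __array__ object -> jax.Array

try:
    to_array_input([1.0, 2.0])
    list_rejected = False
except TypeError:
    list_rejected = True
```

Rejecting sequences rather than converting them avoids silently turning a ragged or nested Python structure into an array of unexpected shape.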
- Parameters:
  - tesseract_client (Tesseract) – The Tesseract object to apply.
  - inputs (Any) – The inputs to apply to the Tesseract object.
  - vmap_method (Optional[Literal['sequential', 'auto_experimental', 'expand_dims', 'broadcast_all']]) – Strategy for handling jax.vmap batching. Must be set explicitly when using jax.vmap; applying jax.vmap with the default None raises NotImplementedError.
    - None (default): No vmap support. Raises NotImplementedError if jax.vmap is applied. All other JAX transforms (jit, grad) work normally.
    - "sequential": Calls the Tesseract once per batch element via jax.lax.map. Safe for all Tesseracts regardless of schema.
    - "auto_experimental": Experimental. Inspects the differentiable input schema at trace time. When all batched differentiable inputs use Array[..., dtype] (ellipsis shape), adds a leading (1,) dimension to unbatched args and sends a single batched call; otherwise falls back to sequential. Only differentiable inputs are considered; non-differentiable array inputs are not yet supported.
    - "expand_dims": Adds a leading (1,) dimension to every unbatched array arg and sends a single batched call. The Tesseract must broadcast (1, ...) against (batch, ...) internally. Use this when the Tesseract accepts a leading batch dimension on all inputs.
    - "broadcast_all": Broadcasts every unbatched array arg to (batch, ...), so that all args share the same leading dimension. Use this when the Tesseract requires all inputs to have identical shapes.
    Python scalars (float, int, bool) are always static and are never batched, regardless of the chosen method. Scalar arrays (0-d, e.g. Float64) are treated as regular array args and are transformed according to the method. See Batching strategies for jax.vmap for a detailed guide.
- Return type:
- Returns:
The outputs of the Tesseract object after applying the inputs.
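To make the vmap_method strategies concrete, the sketch below reproduces, with plain JAX arrays, the shapes a callee would receive under "sequential", "expand_dims" and "broadcast_all". The function tesseract_like is a hypothetical stand-in for a Tesseract that broadcasts its inputs; this mimics the described shape handling and is not how tesseract_jax implements the strategies.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for a Tesseract computing an elementwise product.
def tesseract_like(a, b):
    return a * b


batch = 4
a = jnp.arange(batch * 3, dtype=jnp.float32).reshape(batch, 3)  # batched arg
b = jnp.array([1.0, 2.0, 3.0])                                   # unbatched arg

# "sequential": one call per batch element via jax.lax.map.
seq = jax.lax.map(lambda a_i: tesseract_like(a_i, b), a)

# "expand_dims": unbatched args get a leading (1,) dim; the callee itself
# must broadcast (1, 3) against (4, 3).
b_expanded = jnp.expand_dims(b, 0)
expanded = tesseract_like(a, b_expanded)

# "broadcast_all": unbatched args are materialised at the full batch shape.
b_broadcast = jnp.broadcast_to(b, (batch,) + b.shape)
broadcast = tesseract_like(a, b_broadcast)

# Reference result from plain jax.vmap over the first argument only.
reference = jax.vmap(tesseract_like, in_axes=(0, None))(a, b)
```

With a real Tesseract, the strategy is selected on the apply_tesseract call inside the function that jax.vmap maps over, e.g. apply_tesseract(tesseract_client, inputs, vmap_method="sequential"). "sequential" trades speed for safety (one call per element), while the two single-call strategies only give correct results if the Tesseract handles the leading batch dimension as described.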
- tesseract_jax.save_intermediates(fn, *, tag='intermediates')
Functional transformation that captures values tagged by sow().
Returns a new function with the same signature as fn, but whose return value is a tuple (original_result, intermediates), where intermediates is a dictionary mapping sow names to sub-dictionaries with keys "primal", "tangent", and/or "cotangent".
Which keys are present depends on the JAX transformations applied to fn before wrapping with save_intermediates:
  - Plain call: only "primal"
  - jax.grad / jax.vjp: "primal" and "cotangent"
  - jax.jvp: "primal" and "tangent"
save_intermediates should be the outermost transformation. It recursively descends into sub-jaxprs (e.g. inside jax.jit boundaries), so sow calls inside JIT-compiled functions are captured correctly.
- Parameters:
- Return type:
- Returns:
A new callable (*args, **kwargs) -> (result, intermediates).
- Raises:
ValueError – If a sow name is used more than once inside fn.
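The meaning of the "primal", "tangent" and "cotangent" entries can be illustrated by computing them by hand with jax.jvp and jax.vjp. In the sketch below, f is split into hypothetical inner and outer functions at the point where one would write y = sow(inner(x), "y"). It does not call save_intermediates; it assumes the recorded entries follow the usual JAX meanings of these terms, which is what the key names suggest.

```python
import jax
import jax.numpy as jnp


# f is split at the point where one would write y = sow(inner(x), "y").
def inner(x):
    return jnp.sin(x)


def outer(y):
    return jnp.sum(y**2)


def f(x):
    return outer(inner(x))


x = jnp.array([0.0, 0.5, 1.0])

# "primal": the value of y itself.
y = inner(x)

# "cotangent" (under jax.grad / jax.vjp): the cotangent flowing back into y
# when the scalar output of f is seeded with 1.0, i.e. d f / d y.
out, outer_vjp = jax.vjp(outer, y)
(y_cotangent,) = outer_vjp(jnp.ones_like(out))  # equals 2 * y here

# "tangent" (under jax.jvp): the directional derivative of y along dx.
dx = jnp.ones_like(x)
_, y_tangent = jax.jvp(inner, (x,), (dx,))      # equals cos(x) * dx here
```

By the chain rule, jax.grad(f)(x) equals y_cotangent * cos(x), which is why the cotangent at an intermediate is useful for inspecting how sensitivity flows through a pipeline.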
- tesseract_jax.sow(value, name, *, tag='intermediates')
Tag an intermediate value for capture by save_intermediates().
Acts as the identity function: the return value is always equal to value. When the enclosing function is later wrapped with save_intermediates(), the tagged value (and, if a derivative transformation is active, its tangent or cotangent) appears in the returned intermediates dictionary.
- Parameters:
  - value (TypeVar(T)) – Any JAX-compatible pytree (dict, list, array, nested combinations, …).
  - name (str) – A unique string name used to identify this intermediate in the dictionary returned by save_intermediates(). Using the same name twice inside a single function raises ValueError when that function is processed by save_intermediates().
  - tag (str) – An optional string tag for grouping intermediates. Only intermediates whose tag matches the one passed to save_intermediates() are captured. Defaults to "intermediates".
- Return type:
TypeVar(T)
- Returns:
value, unchanged.