Debugging pipelines
When building pipelines that chain multiple Tesseracts, it can be difficult to understand what values flow between steps — especially when gradients are involved. Tesseract-JAX provides sow() and save_intermediates() to help you inspect intermediate values and their derivatives.
Basic usage
Use sow() to tag any intermediate value with a name, then wrap your function with save_intermediates() to extract tagged values:
from tesseract_jax import apply_tesseract, sow, save_intermediates

def my_pipeline(inputs):
    res = apply_tesseract(tess1, inputs)
    res = sow(res, "after_tess1")  # tag this intermediate
    res = apply_tesseract(tess2, res)
    return res["output"].sum()

result, intermediates = save_intermediates(my_pipeline)(inputs)

# intermediates["after_tess1"]["primal"] contains the forward-pass value
print(intermediates["after_tess1"]["primal"])
sow() acts as a pure identity function — it returns its input unchanged and has no effect on the computation. Tagged values are only captured when the function is wrapped with save_intermediates().
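The contract described above (identity when unwrapped, capture only when wrapped) can be illustrated with a toy model in plain Python. This is only a conceptual sketch: `toy_sow` and `toy_save_intermediates` are hypothetical names, and the mechanism shown here (a module-level store) is an assumption made for illustration, not how Tesseract-JAX implements these functions.

```python
# Toy model of the tagging contract. NOT the Tesseract-JAX implementation.
_active_store = None  # None means no capture is in progress


def toy_sow(value, name):
    """Return `value` unchanged; record it only while a capture is active."""
    if _active_store is not None:
        _active_store[name] = {"primal": value}
    return value


def toy_save_intermediates(fn):
    """Wrap `fn` so that it returns (result, captured_values)."""
    def wrapped(*args, **kwargs):
        global _active_store
        previous, _active_store = _active_store, {}
        try:
            result = fn(*args, **kwargs)
            return result, _active_store
        finally:
            # Restore the previous state so later plain calls capture nothing
            _active_store = previous
    return wrapped


def pipeline(x):
    y = toy_sow(x * 2, "doubled")  # tagged, but otherwise unchanged
    return y + 1


plain_result = pipeline(3)  # plain call: nothing is captured
result, captured = toy_save_intermediates(pipeline)(3)
print(result, captured)
```

The plain call and the wrapped call return the same result; only the wrapped call additionally yields the tagged values.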
Capturing gradients
When save_intermediates() wraps a function that involves a gradient transformation, it captures derivatives alongside primal values automatically.
With jax.grad (cotangents)
import jax
grad_fn = jax.grad(my_pipeline)
grads, intermediates = save_intermediates(grad_fn)(inputs)
# Forward-pass value at the tagged point
intermediates["after_tess1"]["primal"]
# Cotangent (gradient flowing back through this point)
intermediates["after_tess1"]["cotangent"]
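To build intuition for what the cotangent entry represents, the following sketch uses plain JAX only. The functions `stage1` and `stage2` are hypothetical stand-ins for the two Tesseract calls, and nothing here calls Tesseract-JAX. Under that assumption, the cotangent at the tagged point is the gradient of the final scalar with respect to the intermediate value.

```python
import jax
import jax.numpy as jnp


def stage1(x):
    # Hypothetical stand-in for the first Tesseract
    return x ** 2


def stage2(h):
    # Hypothetical stand-in for the second Tesseract plus the final reduction
    return jnp.sum(3.0 * h)


x = jnp.array([1.0, 2.0])

# Forward value at the point between the two stages (the "primal")
h = stage1(x)

# Gradient of the final scalar with respect to that intermediate value
cotangent_h = jax.grad(stage2)(h)

# Chain rule: pulling the cotangent back through stage1 gives the full gradient
_, stage1_vjp = jax.vjp(stage1, x)
(grad_x,) = stage1_vjp(cotangent_h)
full_grad = jax.grad(lambda v: stage2(stage1(v)))(x)

print(cotangent_h, grad_x, full_grad)
```

Comparing `grad_x` with `full_grad` is a quick sanity check that the value flowing back through the intermediate point accounts for the whole gradient.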
With jax.jvp (tangents)
def forward_with_jvp(x):
    # dx is the tangent direction; it must match the structure of x
    # and is assumed to be defined elsewhere
    primals, tangents = jax.jvp(my_pipeline, (x,), (dx,))
    return primals

result, intermediates = save_intermediates(forward_with_jvp)(inputs)

# Forward-pass value
intermediates["after_tess1"]["primal"]

# Tangent (derivative propagated forward through this point)
intermediates["after_tess1"]["tangent"]
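The tangent entry can be understood the same way in forward mode. The sketch below uses plain JAX with hypothetical stand-in functions (`stage1`, `stage2`) instead of Tesseracts. It assumes the tangent at the tagged point is the input perturbation `dx` pushed forward through everything before that point.

```python
import jax
import jax.numpy as jnp


def stage1(x):
    # Hypothetical stand-in for the first Tesseract
    return jnp.sin(x)


def stage2(h):
    # Hypothetical stand-in for the second Tesseract plus the final reduction
    return jnp.sum(h ** 2)


x = jnp.array([0.0, 1.0])
dx = jnp.ones_like(x)  # tangent direction for the input

# Primal and tangent at the point between the two stages
h, tangent_h = jax.jvp(stage1, (x,), (dx,))

# Pushing that tangent through stage2 matches the JVP of the whole pipeline
_, tangent_out = jax.jvp(stage2, (h,), (tangent_h,))
_, tangent_full = jax.jvp(lambda v: stage2(stage1(v)), (x,), (dx,))

print(tangent_h, tangent_out, tangent_full)
```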
With jax.vjp
import jax.numpy as jnp

def forward_with_vjp(x):
    primals, f_vjp = jax.vjp(my_pipeline, x)
    # Seed the backward pass with ones shaped like the output
    grads = f_vjp(jnp.ones_like(primals))
    return grads[0]

result, intermediates = save_intermediates(forward_with_vjp)(inputs)

intermediates["after_tess1"]["primal"]
intermediates["after_tess1"]["cotangent"]
Multiple tagged values
You can tag as many intermediates as you like — each must have a unique name:
def my_pipeline(inputs):
    res1 = apply_tesseract(tess1, inputs)
    res1 = sow(res1, "after_tess1")
    res2 = apply_tesseract(tess2, res1)
    res2 = sow(res2, "after_tess2")
    return res2["output"].sum()

grads, intermediates = save_intermediates(jax.grad(my_pipeline))(inputs)

# Inspect each step independently
print(intermediates["after_tess1"]["primal"])
print(intermediates["after_tess1"]["cotangent"])
print(intermediates["after_tess2"]["primal"])
print(intermediates["after_tess2"]["cotangent"])
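With several tagged points, printing every value by hand gets tedious. One possible approach is a small helper that scans the captured values for NaN or infinite entries. The sketch below assumes `intermediates` has the shape shown above (tag name, then key, then a pytree of arrays). The helper `find_non_finite` and the mock data are hypothetical and not part of Tesseract-JAX.

```python
import jax
import jax.numpy as jnp


def find_non_finite(intermediates):
    """Return (tag, key) pairs whose captured value contains NaN or inf."""
    bad = []
    for tag, captured in intermediates.items():
        for key, value in captured.items():
            # Captured values may be pytrees (e.g. dicts of arrays)
            leaves = jax.tree_util.tree_leaves(value)
            if any(not bool(jnp.all(jnp.isfinite(leaf))) for leaf in leaves):
                bad.append((tag, key))
    return bad


# Mock data standing in for the output of save_intermediates()
mock_intermediates = {
    "after_tess1": {
        "primal": {"output": jnp.array([1.0, 2.0])},
        "cotangent": {"output": jnp.array([1.0, 1.0])},
    },
    "after_tess2": {
        "primal": {"output": jnp.array([0.5, float("nan")])},
        "cotangent": {"output": jnp.array([1.0, 1.0])},
    },
}

print(find_non_finite(mock_intermediates))
```

Walking the tags in pipeline order makes it easier to spot the first step at which a value goes bad.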
Compatibility with jax.jit
sow() works inside jax.jit-compiled functions. save_intermediates() traces into JIT boundaries and captures intermediates correctly:
@jax.jit
def my_pipeline(inputs):
    res = apply_tesseract(tess1, inputs)
    res = sow(res, "after_tess1")
    return res["output"].sum()

# Works as expected: intermediates are captured from inside jit
result, intermediates = save_intermediates(my_pipeline)(inputs)
Note
save_intermediates() should be the outermost transformation. It works by rewriting the function’s JAX program trace, so it needs to wrap everything else.
Summary of captured keys
The keys present in each intermediate’s dictionary depend on which JAX transformations are active:
| Transformation | Keys captured |
|---|---|
| Plain call | primal |
| jax.grad | primal, cotangent |
| jax.jvp | primal, tangent |
| jax.vjp | primal, cotangent |
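To check which keys a particular run actually captured, a short helper can list them per tag. This is a minimal sketch in plain Python. The name `captured_keys` is hypothetical, and the mock dictionary only imitates the structure described in this guide.

```python
def captured_keys(intermediates):
    """Map each tag name to the sorted list of keys captured for it."""
    return {tag: sorted(captured) for tag, captured in intermediates.items()}


# Mock data imitating a run under a reverse-mode transformation
mock_intermediates = {
    "after_tess1": {"primal": 1.0, "cotangent": 0.5},
    "after_tess2": {"primal": 2.0, "cotangent": 1.0},
}

print(captured_keys(mock_intermediates))
```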