\n",
"\n",
"All examples are expected to run from the `examples/` directory of the [Tesseract-JAX repository](https://github.com/pasteurlabs/tesseract-jax).\n",
"
\n",
"\n",
"Tesseract-JAX is a lightweight extension to [Tesseract Core](https://github.com/pasteurlabs/tesseract-core) that makes Tesseracts look and feel like regular [JAX](https://github.com/jax-ml/jax) primitives, and makes them jittable, differentiable, and composable.\n",
"\n",
"In this example, you will learn how to:\n",
"1. Build a Tesseract that performs vector addition.\n",
"1. Access its endpoints via Tesseract-JAX's `apply_tesseract()` function.\n",
"1. Compose Tesseracts into more complex functions, blending multiple Tesseract applications with local operations.\n",
"2. Apply `jax.jit` to the resulting pipeline to perform JIT compilation, and / or autodifferentiate the function (via `jax.grad`, `jax.jvp`, `jax.vjp`, ...)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 1: Build + serve example Tesseract\n",
"\n",
"In this example, we build and use a Tesseract that performs vector addition. The example Tesseract takes two vectors and scalars as input and return some statistics as output. Here is the functionality that's implemented in the Tesseract (see `vectoradd_jax/tesseract_api.py`):\n",
"\n",
"```python\n",
"def apply_jit(inputs: dict) -> dict:\n",
" a_scaled = inputs[\"a\"][\"s\"] * inputs[\"a\"][\"v\"]\n",
" b_scaled = inputs[\"b\"][\"s\"] * inputs[\"b\"][\"v\"]\n",
" add_result = a_scaled + b_scaled\n",
" min_result = a_scaled - b_scaled\n",
"\n",
" def safe_norm(x, ord):\n",
" # Compute the norm of a vector, adding a small epsilon to ensure\n",
" # differentiability and avoid division by zero\n",
" return jnp.power(jnp.power(jnp.abs(x), ord).sum() + 1e-8, 1 / ord)\n",
"\n",
" return {\n",
" \"vector_add\": {\n",
" \"result\": add_result,\n",
" \"normed_result\": add_result / safe_norm(add_result, ord=inputs[\"norm_ord\"]),\n",
" },\n",
" \"vector_min\": {\n",
" \"result\": min_result,\n",
" \"normed_result\": min_result / safe_norm(min_result, ord=inputs[\"norm_ord\"]),\n",
" },\n",
" }\n",
"```\n",
"\n",
"You may build the example Tesseract either via the command line, or running the cell below (you can skip running this if already built)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2K \u001b[1;2m[\u001b[0m\u001b[34mi\u001b[0m\u001b[1;2m]\u001b[0m Building image \u001b[33m...\u001b[0m\n",
"\u001b[2K\u001b[37m⠙\u001b[0m \u001b[37mProcessing\u001b[0m\n",
"\u001b[1A\u001b[2K \u001b[1;2m[\u001b[0m\u001b[34mi\u001b[0m\u001b[1;2m]\u001b[0m Built image sh\u001b[1;92ma256:7ae8\u001b[0m5ba85970, \u001b[1m[\u001b[0m\u001b[32m'vectoradd_jax:latest'\u001b[0m\u001b[1m]\u001b[0m\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[\"vectoradd_jax:latest\"]\n"
]
}
],
"source": [
"%%bash\n",
"# Build vectoradd_jax Tesseract so we can use it below\n",
"tesseract build vectoradd_jax/"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To interact with the Tesseract, we use the Python SDK from `tesseract_core` to load the built image and start a server container."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from tesseract_core import Tesseract\n",
"\n",
"vectoradd = Tesseract.from_image(\"vectoradd_jax\")\n",
"vectoradd.serve()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 2: Invoke the Tesseract via Tesseract-JAX\n",
"\n",
"Using the `vectoradd_jax` Tesseract image we built earlier, let's add two vectors together, representing the following operation:\n",
"\n",
"$$\\begin{pmatrix} 1 \\\\ 2 \\\\ 3 \\end{pmatrix} + 2 \\cdot \\begin{pmatrix} 4 \\\\ 5 \\\\ 6 \\end{pmatrix} = \\begin{pmatrix} 9 \\\\ 12 \\\\ 15 \\end{pmatrix}$$\n",
"\n",
"We can perform this calculation using the function `tesseract_jax.apply_tesseract()`, by passing the `Tesseract` object and the input data as a PyTree (nested dictionary) of JAX arrays as inputs."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'vector_add': {'normed_result': Array([0.42426407, 0.56568545, 0.70710677], dtype=float32),\n",
" 'result': Array([ 9., 12., 15.], dtype=float32)},\n",
" 'vector_min': {'normed_result': Array([-0.5025707 , -0.5743665 , -0.64616233], dtype=float32),\n",
" 'result': Array([-7., -8., -9.], dtype=float32)}}\n"
]
}
],
"source": [
"from pprint import pprint\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"from tesseract_jax import apply_tesseract\n",
"\n",
"a = {\"v\": jnp.array([1.0, 2.0, 3.0], dtype=\"float32\")}\n",
"b = {\n",
" \"v\": jnp.array([4.0, 5.0, 6.0], dtype=\"float32\"),\n",
" \"s\": jnp.array(2.0, dtype=\"float32\"),\n",
"}\n",
"\n",
"outputs = apply_tesseract(vectoradd, inputs={\"a\": a, \"b\": b})\n",
"pprint(outputs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As expected, `outputs['vector_add']` gives a value of $(9, 12, 15)$."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 3: Function composition via Tesseracts\n",
"\n",
"Tesseract-JAX enables you to compose chains of Tesseract evaluations, blended with local operations, while retaining all the benefits of JAX.\n",
"\n",
"The function below applies `vectoradd` twice, *ie.* $(\\mathbf{a} + \\mathbf{b}) + \\mathbf{a}$, then performs local arithmetic on the outputs, applies `vectoradd` once more, and finally returns a single element of the result. The resulting function is still a valid JAX function, and is fully jittable and auto-differentiable."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Array(16.135319, dtype=float32)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def fancy_operation(a: jax.Array, b: jax.Array) -> jnp.float32:\n",
" \"\"\"Fancy operation.\"\"\"\n",
" result = apply_tesseract(vectoradd, inputs={\"a\": a, \"b\": b})\n",
" result = apply_tesseract(\n",
" vectoradd, inputs={\"a\": {\"v\": result[\"vector_add\"][\"result\"]}, \"b\": b}\n",
" )\n",
" # We can mix and match with local JAX operations\n",
" result = 2.0 * result[\"vector_add\"][\"normed_result\"] + b[\"v\"]\n",
" result = apply_tesseract(vectoradd, inputs={\"a\": {\"v\": result}, \"b\": b})\n",
" return result[\"vector_add\"][\"result\"][1]\n",
"\n",
"\n",
"fancy_operation(a, b)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is compatible with `jax.jit()`:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Array(16.135319, dtype=float32)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"jitted_op = jax.jit(fancy_operation)\n",
"jitted_op(a, b)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Autodifferentiation is automatically dispatched to the underlying Tesseract's `jacobian_vector_product` and `vector_jacobian_product` endpoints, and works as expected:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"jax.grad result:\n",
"({'v': Array([-0.01284981, 0.03497622, -0.02040852], dtype=float32)},\n",
" {'s': Array(5.002062, dtype=float32),\n",
" 'v': Array([-0.05139923, 3.139905 , -0.08163408], dtype=float32)})\n",
"\n",
"jax.jvp result:\n",
"Array(25.004124, dtype=float32)\n",
"\n",
"jax.vjp result:\n",
"({'v': Array([-0.01284981, 0.03497622, -0.02040852], dtype=float32)},\n",
" {'s': Array(5.002062, dtype=float32),\n",
" 'v': Array([-0.05139923, 3.139905 , -0.08163408], dtype=float32)})\n"
]
}
],
"source": [
"# jax.grad for reverse-mode autodiff (scalar outputs only)\n",
"grad_res = jax.grad(fancy_operation, argnums=[0, 1])(a, b)\n",
"print(\"jax.grad result:\")\n",
"pprint(grad_res)\n",
"\n",
"# jax.jvp for general forward-mode autodiff\n",
"_, jvp = jax.jvp(fancy_operation, (a, b), (a, b))\n",
"print(\"\\njax.jvp result:\")\n",
"pprint(jvp)\n",
"\n",
"# jax.vjp for general reverse-mode autodiff\n",
"_, vjp_fn = jax.vjp(fancy_operation, a, b)\n",
"vjp = vjp_fn(1.0)\n",
"print(\"\\njax.vjp result:\")\n",
"pprint(vjp)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"All the above also works when combining with `jit`:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"jax.grad result:\n",
"({'v': Array([-0.01284981, 0.03497622, -0.02040852], dtype=float32)},\n",
" {'s': Array(5.002062, dtype=float32),\n",
" 'v': Array([-0.05139923, 3.139905 , -0.08163408], dtype=float32)})\n"
]
}
],
"source": [
"# jax.grad for reverse-mode autodiff (scalar output)\n",
"grad_res = jax.jit(jax.grad(fancy_operation, argnums=[0, 1]))(a, b)\n",
"print(\"jax.grad result:\")\n",
"pprint(grad_res)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step N+1: Clean-up and conclusions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since we kept the Tesseract alive using `.serve()`, we need to manually stop it using `.teardown()` to avoid leaking resources. \n",
"\n",
"This is not necessary when using `Tesseract` in a `with` statement, as it will automatically clean up when the context is exited."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"vectoradd.teardown()"
]
},
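{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, here is a minimal sketch of the context-manager pattern mentioned above (assuming the same `vectoradd_jax` image built earlier); the served Tesseract is cleaned up automatically when the `with` block exits:\n",
"\n",
"```python\n",
"from tesseract_core import Tesseract\n",
"\n",
"from tesseract_jax import apply_tesseract\n",
"\n",
"with Tesseract.from_image(\"vectoradd_jax\") as vectoradd:\n",
"    outputs = apply_tesseract(vectoradd, inputs={\"a\": a, \"b\": b})\n",
"# No explicit .teardown() is needed here -- the server is stopped on exit\n",
"```"
]
},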
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And that's it! \n",
"You've learned how to build up differentiable pipelines with Tesseracts that blend seamlessly with JAX's APIs and transformations."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "science",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}