# Source code for tesseract_torch.function

# Copyright 2025 Pasteur Labs. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""Differentiable PyTorch wrapper for Tesseract operations.

This module registers a Tesseract as a first-class differentiable primitive in
PyTorch's autograd graph.  The forward pass dispatches to ``tesseract.apply()``,
the backward pass to ``tesseract.vector_jacobian_product()``, and the
forward-mode JVP to ``tesseract.jacobian_vector_product()``.
"""

from __future__ import annotations

from typing import Any

import numpy as np
import torch
from tesseract_core import Tesseract


def _to_tensor(arr: Any) -> torch.Tensor:
    """Convert a numpy array to a float32 tensor, copying if read-only."""
    a = np.asarray(arr)
    if not a.flags.writeable:
        a = a.copy()
    return torch.as_tensor(a)


def _tensor_to_numpy(t: torch.Tensor) -> np.ndarray:
    """Convert a torch tensor to a numpy array, preserving dtype."""
    # torch.func transforms (vjp, jvp, grad, vmap) wrap tensors in a C++
    # FunctionalTensorWrapper that has no backing storage.  These tensors
    # report type(t)==torch.Tensor (no Python subclass), so there is no
    # isinstance check we can use.  Instead we probe data_ptr(), the same
    # public precondition that .numpy() relies on, to raise an actionable
    # error instead of the confusing default message ("Cannot access data
    # pointer of Tensor that doesn't have storage").
    try:
        t.data_ptr()
    except RuntimeError:
        raise RuntimeError(
            "apply_tesseract does not support torch.func transforms "
            "(torch.func.vjp, torch.func.jvp, torch.func.grad, etc.). "
            "Use the standard autograd API instead:\n"
            "  - Reverse mode: result['y'].backward() or torch.autograd.grad()\n"
            "  - Forward mode: torch.autograd.forward_ad (dual tensors)"
        ) from None
    return t.detach().cpu().numpy()


def _get_differentiable_arrays(
    openapi_schema: dict,
    component: str,
) -> set[str]:
    """Extract differentiable array dotted-paths from the OpenAPI schema."""
    schema = openapi_schema["components"]["schemas"].get(component, {})
    return set(schema.get("differentiable_arrays", {}))


# ---------------------------------------------------------------------------
# Pytree helpers - flatten / unflatten nested dicts using dotted paths
# ---------------------------------------------------------------------------


def _flatten_pytree(
    tree: dict[str, Any],
    prefix: str = "",
    *,
    recurse_into: set[str] | None = None,
) -> list[tuple[str, Any]]:
    """Flatten nested dicts into a list of ``(dotted_path, leaf)`` pairs.

    A sub-dict is expanded only when ``_should_recurse`` approves it: it must
    be non-empty, and either *recurse_into* is ``None`` or some entry of
    *recurse_into* lies strictly beneath the sub-dict's dotted path.  Any
    other dict (empty ones, opaque ``dict[str, Array]`` schema fields) is
    emitted as a single leaf value.

    Pairs come out depth-first in insertion order.  When *prefix* is
    non-empty, it is joined with ``.`` in front of every path.
    """
    pairs: list[tuple[str, Any]] = []
    for key, value in tree.items():
        path = key if not prefix else f"{prefix}.{key}"
        expand = isinstance(value, dict) and _should_recurse(path, value, recurse_into)
        if not expand:
            pairs.append((path, value))
            continue
        pairs += _flatten_pytree(value, path, recurse_into=recurse_into)
    return pairs


def _should_recurse(
    path: str,
    value: dict,
    known_paths: set[str] | None,
) -> bool:
    """Return True when *path* is a prefix of a known leaf path."""
    if not value:
        return False
    if known_paths is None:
        return True
    dot_prefix = path + "."
    return any(p.startswith(dot_prefix) for p in known_paths)


def _unflatten_pytree(flat: dict[str, Any]) -> dict[str, Any]:
    """Reconstruct a nested dict from ``{dotted_path: value}``."""
    tree: dict[str, Any] = {}
    for path, value in flat.items():
        parts = path.split(".")
        node = tree
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value
    return tree


# ---------------------------------------------------------------------------
# Core autograd function
# ---------------------------------------------------------------------------


class _TesseractFunction(torch.autograd.Function):
    """Low-level autograd function wrapping a Tesseract.

    This is an implementation detail.  Users should call :func:`apply_tesseract`.

    ``forward`` takes six leading non-tensor arguments, followed by the
    differentiable input tensors in ``diff_input_names`` order:

    - the Tesseract;
    - the flat differentiable input names;
    - the flat differentiable output names;
    - the dotted paths that guide pytree flattening;
    - the static inputs;
    - a result holder.

    ``backward`` therefore returns six leading ``None`` gradients, and
    ``jvp`` skips the first six tangents.
    """

    @staticmethod
    def forward(
        tesseract: Tesseract,
        diff_input_names: list[str],
        diff_output_names: list[str],
        all_paths: set[str],
        static_inputs: dict[str, Any],
        non_diff_result_holder: list[dict[str, Any]],
        *tensors: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        """Run the Tesseract forward pass, returning differentiable outputs.

        The differentiable tensors are converted to NumPy, merged over
        *static_inputs*, unflattened, and passed to ``tesseract.apply()``.
        The full (flat) result dict is appended to *non_diff_result_holder*,
        so the caller can reconstruct non-differentiable outputs without a
        second ``apply()`` call.

        Returns:
            One tensor per entry of *diff_output_names*, in that order.
        """
        flat_inputs = dict(static_inputs)
        for name, tensor in zip(diff_input_names, tensors, strict=True):
            flat_inputs[name] = _tensor_to_numpy(tensor)

        result = tesseract.apply(_unflatten_pytree(flat_inputs))
        flat_result = dict(_flatten_pytree(result, recurse_into=all_paths))

        # Stash full result for the caller
        non_diff_result_holder.append(flat_result)

        # Return only the differentiable output tensors (in the caller's order)
        return tuple(_to_tensor(flat_result[name]) for name in diff_output_names)

    @staticmethod
    def setup_context(
        ctx: Any,
        inputs: tuple[Any, ...],
        outputs: tuple[torch.Tensor, ...],
    ) -> None:
        """Save forward-pass metadata for use in backward / jvp.

        Stores the Tesseract, the name lists, a private copy of the inputs
        the forward pass saw, and device/shape/dtype metadata for the
        differentiable inputs and outputs.
        """
        (
            tesseract,
            diff_input_names,
            diff_output_names,
            _all_paths,
            static_inputs,
            _holder,
            *tensors,
        ) = inputs
        ctx.tesseract = tesseract
        ctx.diff_input_names = diff_input_names
        ctx.diff_output_names = diff_output_names

        # The Tesseract returns host arrays, so gradients come back on the
        # CPU; remember where each input lives so backward can return each
        # gradient on the device autograd expects.
        ctx.input_devices = [t.device for t in tensors]
        # Keep only metadata (not the tensors) for the outputs to avoid a
        # ctx -> output reference cycle; jvp needs it to build zero tangents
        # and to match tangent dtype/device to the primal outputs.
        ctx.output_meta = [(o.shape, o.dtype, o.device) for o in outputs]

        saved_inputs: dict[str, Any] = dict(static_inputs)
        for name, tensor in zip(diff_input_names, tensors, strict=True):
            # _tensor_to_numpy shares memory with CPU tensors.  Take a private
            # copy so an in-place edit of the input after forward cannot
            # silently change the point at which backward / jvp linearize.
            saved_inputs[name] = _tensor_to_numpy(tensor).copy()
        ctx.saved_inputs = saved_inputs

    @staticmethod
    def backward(
        ctx: Any,
        *grad_outputs: torch.Tensor,
    ) -> tuple[None | torch.Tensor, ...]:
        """Reverse-mode AD via the Tesseract's VJP endpoint.

        Inputs that the VJP result does not mention receive a ``None``
        gradient.  All other gradients are placed on the device of the
        corresponding input tensor.
        """
        cotangent_vector = {
            name: _tensor_to_numpy(grad)
            for name, grad in zip(ctx.diff_output_names, grad_outputs, strict=True)
        }

        vjp_result = ctx.tesseract.vector_jacobian_product(
            inputs=_unflatten_pytree(ctx.saved_inputs),
            vjp_inputs=list(ctx.diff_input_names),
            vjp_outputs=list(ctx.diff_output_names),
            cotangent_vector=cotangent_vector,
        )

        grad_inputs: list[torch.Tensor | None] = []
        for name, device in zip(ctx.diff_input_names, ctx.input_devices, strict=True):
            g = vjp_result.get(name)
            # Autograd rejects gradients on a different device than the input.
            grad_inputs.append(None if g is None else _to_tensor(g).to(device=device))

        # None for (tesseract, diff_input_names, diff_output_names,
        #           all_paths, static_inputs, non_diff_result_holder)
        return (None, None, None, None, None, None, *grad_inputs)

    @staticmethod
    def jvp(
        ctx: Any,
        *tangents: torch.Tensor | None,
    ) -> tuple[torch.Tensor, ...]:
        """Forward-mode AD via the Tesseract's JVP endpoint.

        Returns one tangent per differentiable output, matching that
        output's dtype and device.
        """
        # tangents: (tesseract, diff_input_names, diff_output_names,
        #            all_paths, static_inputs, holder, *tensor_tangents)
        tensor_tangents = tangents[6:]

        tangent_vector: dict[str, Any] = {}
        jvp_inputs: list[str] = []
        for name, t in zip(ctx.diff_input_names, tensor_tangents, strict=True):
            if t is not None:
                tangent_vector[name] = _tensor_to_numpy(t)
                jvp_inputs.append(name)

        if not jvp_inputs:
            # No input carries a tangent, so every output tangent is zero.
            # It must be shaped like the corresponding primal *output*.
            return tuple(
                torch.zeros(shape, dtype=dtype, device=device)
                for shape, dtype, device in ctx.output_meta
            )

        jvp_result = ctx.tesseract.jacobian_vector_product(
            inputs=_unflatten_pytree(ctx.saved_inputs),
            jvp_inputs=jvp_inputs,
            jvp_outputs=list(ctx.diff_output_names),
            tangent_vector=tangent_vector,
        )

        return tuple(
            _to_tensor(jvp_result[name]).to(dtype=dtype, device=device)
            for name, (_shape, dtype, device) in zip(
                ctx.diff_output_names, ctx.output_meta, strict=True
            )
        )


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def apply_tesseract(
    tesseract: Tesseract,
    inputs: dict[str, Any],
) -> dict[str, Any]:
    """Call a Tesseract as a differentiable PyTorch operation.

    Infers which inputs/outputs are differentiable from the Tesseract's
    schema.  Torch tensors provided for differentiable fields participate in
    autograd; all other values are passed through as static inputs.
    Supports both reverse-mode (``.backward()``) and forward-mode
    (``torch.autograd.forward_ad``) differentiation.

    Args:
        tesseract: A Tesseract instance.
        inputs: Nested dict matching the Tesseract's input schema.  Provide
            ``torch.Tensor`` for array fields you want gradients through, and
            plain Python / NumPy values for everything else.  Tensors given
            for non-differentiable fields are converted to NumPy and receive
            no gradient.

    Returns:
        Nested dict matching the Tesseract's output schema.  Differentiable
        array outputs are ``torch.Tensor`` (with ``grad_fn`` when inputs
        require grad); non-differentiable outputs are returned as-is (NumPy
        arrays or scalars).

    Example::

        # Flat schema
        result = apply_tesseract(quadratic, {"x": x, "A": A, "b": b})
        result["y"].sum().backward()

        # Nested schema
        result = apply_tesseract(meshstats, {
            "mesh": {"n_points": 3, ..., "points": points_tensor}
        })
        result["statistics"]["barycenter"].sum().backward()
    """
    openapi = tesseract.openapi_schema
    diff_in_paths = _get_differentiable_arrays(openapi, "ApplyInputSchema")
    diff_out_paths = _get_differentiable_arrays(openapi, "ApplyOutputSchema")
    # Sorted so the autograd output order is deterministic.
    diff_out_names = sorted(diff_out_paths)

    # All known dotted paths guide pytree flattening so we recurse into
    # sub-models but not into opaque dict fields.
    all_paths = diff_in_paths | diff_out_paths
    flat_inputs = _flatten_pytree(inputs, recurse_into=all_paths)

    # Partition into differentiable tensors vs static values
    diff_names: list[str] = []
    diff_tensors: list[torch.Tensor] = []
    static: dict[str, Any] = {}
    for path, value in flat_inputs:
        if path in diff_in_paths and isinstance(value, torch.Tensor):
            diff_names.append(path)
            diff_tensors.append(value)
        elif isinstance(value, torch.Tensor):
            static[path] = _tensor_to_numpy(value)
        else:
            static[path] = value

    # Mutable holder so forward() can pass the full result dict back to us
    # without going through autograd's return values.
    result_holder: list[dict[str, Any]] = []
    output_tensors = _TesseractFunction.apply(
        tesseract,
        diff_names,
        diff_out_names,
        all_paths,
        static,
        result_holder,
        *diff_tensors,
    )

    # Reconstruct full output pytree: start from the raw result and swap in
    # the autograd-tracked tensors for the differentiable outputs.
    flat_result = dict(result_holder[0])
    for name, tensor in zip(diff_out_names, output_tensors, strict=True):
        flat_result[name] = tensor
    return _unflatten_pytree(flat_result)