Univariate Rosenbrock function

Context

Example that wraps the univariate Rosenbrock function, a common test problem for optimization algorithms. It defines a Tesseract that implements the required apply endpoint along with all optional endpoints: abstract_eval, jacobian, jacobian_vector_product, and vector_jacobian_product.

See also

This example (and using it to perform optimization) is also part of an Expert Showcase in the Tesseract Community Forum.

Example Tesseract (examples/univariate)

Core functionality — schemas and apply function

This example uses a pure-Python implementation of the Rosenbrock function as the basis for all endpoints:

def rosenbrock(x: float, y: float, a: float = 1.0, b: float = 100.0) -> float:
    return (a - x) ** 2 + b * (y - x**2) ** 2
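The function has its global minimum at (x, y) = (a, a²), where both terms vanish and the value is exactly zero. A quick sanity check (a standalone sketch, independent of any Tesseract machinery):

```python
def rosenbrock(x: float, y: float, a: float = 1.0, b: float = 100.0) -> float:
    return (a - x) ** 2 + b * (y - x**2) ** 2

# Global minimum at (a, a**2): both squared terms vanish there.
assert rosenbrock(1.0, 1.0) == 0.0

# Any other point yields a strictly positive value.
assert rosenbrock(0.0, 0.0) > 0.0
```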

As such, the Tesseract has two differentiable scalar inputs (x and y) and a single output (the value of the Rosenbrock function at those inputs). The parameters a and b are treated as non-differentiable constants.

class InputSchema(BaseModel):
    x: Differentiable[Float32] = Field(description="Scalar value x.", default=0.0)
    y: Differentiable[Float32] = Field(description="Scalar value y.", default=0.0)
    a: Float32 = Field(description="Scalar parameter a.", default=1.0)
    b: Float32 = Field(description="Scalar parameter b.", default=100.0)

class OutputSchema(BaseModel):
    result: Differentiable[Float32] = Field(
        description="Result of Rosenbrock function evaluation."
    )

This makes it straightforward to write the apply function, which simply unpacks the inputs and calls the rosenbrock function with them:

def apply(inputs: InputSchema) -> OutputSchema:
    """Evaluates the Rosenbrock function given input values and parameters."""
    result = rosenbrock(inputs.x, inputs.y, a=inputs.a, b=inputs.b)
    return OutputSchema(result=result)
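Outside a Tesseract runtime, the same control flow can be exercised with a plain stand-in for the pydantic schemas. In this sketch, types.SimpleNamespace is a hypothetical substitute for InputSchema, and the output is a plain dict rather than an OutputSchema instance:

```python
from types import SimpleNamespace

def rosenbrock(x, y, a=1.0, b=100.0):
    return (a - x) ** 2 + b * (y - x**2) ** 2

def apply(inputs):
    """Mirrors the apply endpoint: unpack inputs, evaluate, wrap the result."""
    result = rosenbrock(inputs.x, inputs.y, a=inputs.a, b=inputs.b)
    return {"result": result}

inputs = SimpleNamespace(x=0.0, y=0.0, a=1.0, b=100.0)
assert apply(inputs)["result"] == 1.0  # (1 - 0)^2 + 100 * (0 - 0)^2 = 1
```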

Jacobian endpoint

For the Jacobian, we exploit the fact that the rosenbrock function is traceable by JAX. We can therefore use jax.jacrev to compute the Jacobian of the function with respect to its inputs:

def jacobian(
    inputs: InputSchema,
    jac_inputs: set[str],
    jac_outputs: set[str],
):
    rosenbrock_signature = ["x", "y", "a", "b"]

    jac_result = {dy: {} for dy in jac_outputs}
    for dx in jac_inputs:
        grad_func = jax.jacrev(rosenbrock, argnums=rosenbrock_signature.index(dx))
        for dy in jac_outputs:
            jac_result[dy][dx] = grad_func(inputs.x, inputs.y, inputs.a, inputs.b)

    return jac_result
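The entries that jax.jacrev produces here are just the partial derivatives of the Rosenbrock function. As a cross-check, the closed-form gradient can be compared against central finite differences (a standalone sketch, no JAX required):

```python
def rosenbrock(x, y, a=1.0, b=100.0):
    return (a - x) ** 2 + b * (y - x**2) ** 2

def grad_rosenbrock(x, y, a=1.0, b=100.0):
    # Closed-form partial derivatives with respect to x and y.
    dx = -2 * (a - x) - 4 * b * x * (y - x**2)
    dy = 2 * b * (y - x**2)
    return dx, dy

x, y, h = 0.5, -0.5, 1e-6
dx, dy = grad_rosenbrock(x, y)

# Central finite differences as an independent reference.
fd_dx = (rosenbrock(x + h, y) - rosenbrock(x - h, y)) / (2 * h)
fd_dy = (rosenbrock(x, y + h) - rosenbrock(x, y - h)) / (2 * h)

assert abs(dx - fd_dx) < 1e-4
assert abs(dy - fd_dy) < 1e-4
```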

Other AD endpoints

We define the JVP (Jacobian-vector product) and VJP (vector-Jacobian product) endpoints by summing over rows / columns of the Jacobian matrix. That is, we call jacobian under the hood, then multiply the resulting Jacobian matrix by the (tangent / cotangent) vector input.

Warning

Defining JVP and VJP operations through sums over the full Jacobian matrix is inefficient and negates the benefits of using JVP / VJP. These endpoints are provided for completeness, but in practice, you would typically use JAX's built-in jax.jvp and jax.vjp transformations directly.

def jacobian_vector_product(
    inputs: InputSchema,
    jvp_inputs: set[str],
    jvp_outputs: set[str],
    tangent_vector,
):
    # NOTE: This is a naive implementation of JVP, which is not efficient.
    jac = jacobian(inputs, jvp_inputs, jvp_outputs)
    out = {}
    for dy in jvp_outputs:
        out[dy] = sum(jac[dy][dx] * tangent_vector[dx] for dx in jvp_inputs)
    return out

def vector_jacobian_product(
    inputs: InputSchema,
    vjp_inputs: set[str],
    vjp_outputs: set[str],
    cotangent_vector,
):
    # NOTE: This is a naive implementation of VJP, which is not efficient.
    jac = jacobian(inputs, vjp_inputs, vjp_outputs)
    out = {}
    for dx in vjp_inputs:
        out[dx] = sum(jac[dy][dx] * cotangent_vector[dy] for dy in vjp_outputs)
    return out
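Because this example has a single scalar output, the naive JVP above reduces to a directional derivative, which makes it easy to sanity-check numerically (a standalone sketch using the closed-form gradient; no Tesseract types involved):

```python
def rosenbrock(x, y, a=1.0, b=100.0):
    return (a - x) ** 2 + b * (y - x**2) ** 2

def grad_rosenbrock(x, y, a=1.0, b=100.0):
    # Closed-form partial derivatives with respect to x and y.
    dx = -2 * (a - x) - 4 * b * x * (y - x**2)
    dy = 2 * b * (y - x**2)
    return dx, dy

x, y = 0.2, 0.3
vx, vy = 1.0, -2.0  # tangent vector

# Naive JVP: contract the Jacobian (here, just the gradient) with the tangent.
gx, gy = grad_rosenbrock(x, y)
jvp = gx * vx + gy * vy

# Independent reference: directional derivative via central differences.
h = 1e-6
fd = (rosenbrock(x + h * vx, y + h * vy) - rosenbrock(x - h * vx, y - h * vy)) / (2 * h)

assert abs(jvp - fd) < 1e-4
```

The VJP is the mirror image: with a scalar cotangent w, the analogous contraction is simply (w * gx, w * gy).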

Abstract evaluation

Some Tesseract clients (like Tesseract-JAX) require an abstract evaluation endpoint in order to pre-allocate memory for the inputs and outputs. This is a simple function that returns the shapes of the outputs based on the shapes of the inputs. In this case, the output is always a scalar, so we return an empty shape tuple:

def abstract_eval(abstract_inputs):
    """Calculate output shape of apply from the shape of its inputs."""
    return {"result": ShapeDType(shape=(), dtype="Float32")}