Source code for tesseract_core.sdk.tesseract

# Copyright 2025 Pasteur Labs. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import atexit
import base64
import traceback
from collections.abc import Callable, Mapping, Sequence
from functools import cached_property, wraps
from pathlib import Path
from types import ModuleType
from typing import Any, Literal
from urllib.parse import urlparse, urlunparse

import numpy as np
import requests
from pydantic import BaseModel, TypeAdapter, ValidationError
from pydantic_core import InitErrorDetails

from . import engine

PathLike = str | Path


def requires_client(func: Callable) -> Callable:
    """Guard a Tesseract method so that it only runs with an active client.

    The returned callable forwards to ``func`` when the instance's ``_client``
    attribute is truthy, and raises otherwise.

    Args:
        func: The method to protect. It is called as ``func(self, *args, **kwargs)``.

    Returns:
        A wrapper around ``func`` carrying the same metadata (via ``functools.wraps``).

    Raises:
        RuntimeError: Raised by the wrapper at call time if ``self._client`` is falsy.
    """

    @wraps(func)
    def guarded(self: Tesseract, *args: Any, **kwargs: Any) -> Any:
        # Happy path first: a client is attached, so just delegate.
        if self._client:
            return func(self, *args, **kwargs)

        owner = self.__class__.__name__
        raise RuntimeError(
            f"When creating a {owner} via `from_image`, "
            "you must either use it as a context manager or call .serve() before use."
        )

    return guarded


class Tesseract:
    """A Tesseract.

    This class represents a single Tesseract instance, either remote or local,
    and provides methods to run commands on it and retrieve results.

    Communication between a Tesseract and this class is done either via
    HTTP requests or directly via Python calls to the Tesseract API.

    Instances keep four private attributes:

    * ``_spawn_config``: keyword arguments for ``engine.serve`` (only set by ``from_image``).
    * ``_serve_context``: details of the running container, or ``None`` when not served.
    * ``_lastlog``: logs captured at the last teardown, or ``None``.
    * ``_client``: the client used to run endpoints, or ``None`` when not yet served.
    """

    def __init__(self, url: str) -> None:
        """Connect to an already running Tesseract via its URL.

        Args:
            url: The URL of the Tesseract instance.
        """
        self._spawn_config = None
        self._serve_context = None
        self._lastlog = None
        self._client = HTTPClient(url)

    @classmethod
    def from_url(cls, url: str) -> Tesseract:
        """Create a Tesseract instance from a URL.

        This is useful for connecting to a remote Tesseract instance.

        Args:
            url: The URL of the Tesseract instance.

        Returns:
            A Tesseract instance.
        """
        obj = cls.__new__(cls)
        obj.__init__(url)
        return obj

    @classmethod
    def from_image(
        cls,
        image_name: str,
        *,
        host_ip: str = "127.0.0.1",
        port: str | None = None,
        network: str | None = None,
        network_alias: str | None = None,
        volumes: list[str] | None = None,
        environment: dict[str, str] | None = None,
        gpus: list[str] | None = None,
        num_workers: int = 1,
        user: str | None = None,
        input_path: str | Path | None = None,
        output_path: str | Path | None = None,
        output_format: Literal["json", "json+base64", "json+binref"] = "json",
    ) -> Tesseract:
        """Create a Tesseract instance from a Docker image.

        When using this method, the Tesseract will be spawned in a Docker
        container, serving the Tesseract API via HTTP. To use the Tesseract,
        you need to call the `serve` method or use it as a context manager.

        Example:
            >>> with Tesseract.from_image("my_tesseract") as t:
            ...     # Use tesseract here

        This will automatically teardown the Tesseract when exiting the
        context manager.

        Args:
            image_name: Tesseract image name to serve.
            host_ip: IP address to bind the Tesseracts to.
            port: port or port range to serve each Tesseract on.
            network: name of the network the Tesseract will be attached to.
            network_alias: alias to use for the Tesseract within the network.
            volumes: list of paths to mount in the Tesseract container.
                The list passed in is not modified.
            environment: dictionary of environment variables to pass to the Tesseract.
            gpus: IDs of host Nvidia GPUs to make available to the Tesseracts.
            num_workers: number of workers to use for serving the Tesseracts.
            user: user to run the Tesseracts as, e.g. '1000' or '1000:1000' (uid:gid).
                Defaults to the current user.
            input_path: Input path to read input files from, such as local directory or S3 URI.
            output_path: Output path to write output files to, such as local directory or S3 URI.
            output_format: Format to use for the output data.

        Returns:
            A Tesseract instance.
        """
        obj = cls.__new__(cls)

        if environment is None:
            environment = {}

        # Work on a copy: the input/output mounts appended below must not
        # leak into the list object owned by the caller.
        volumes = [] if volumes is None else list(volumes)

        if input_path is not None:
            input_path = Path(input_path).resolve()
            volumes.append(f"{input_path}:/tesseract/input_data:ro")

        if output_path is not None:
            output_path = Path(output_path).resolve()
            volumes.append(f"{output_path}:/tesseract/output_data:rw")

        obj._spawn_config = dict(
            image_name=image_name,
            volumes=volumes,
            environment=environment,
            gpus=gpus,
            num_workers=num_workers,
            network=network,
            network_alias=network_alias,
            user=user,
            input_path=input_path,
            output_path=output_path,
            output_format=output_format,
            port=port,
            host_ip=host_ip,
            debug=True,
        )
        obj._serve_context = None
        obj._lastlog = None
        # No client yet; one is attached by serve().
        obj._client = None
        return obj

    @classmethod
    def from_tesseract_api(
        cls,
        tesseract_api: str | Path | ModuleType,
        input_path: str | Path | None = None,
        output_path: str | Path | None = None,
        output_format: Literal["json", "json+base64", "json+binref"] = "json",
    ) -> Tesseract:
        """Create a Tesseract instance from a Tesseract API module.

        Warning: This does not use a containerized Tesseract, but rather
        imports the Tesseract API directly. This is useful for debugging,
        but requires a matching runtime environment + all dependencies to be
        installed locally.

        Args:
            tesseract_api: Path to the `tesseract_api.py` file, or an
                already imported Tesseract API module.
            input_path: Path of input directory. All paths in the tesseract
                payload have to be relative to this path.
            output_path: Path of output directory. All paths in the tesseract
                result will be given relative to this path.
            output_format: Format to use for the output data.

        Returns:
            A Tesseract instance.

        Raises:
            RuntimeError: If the given path is not a file, or the module at
                that path cannot be imported.
            FileNotFoundError: If the given path does not exist (raised by
                ``Path.resolve(strict=True)``).
        """
        from tesseract_core.runtime.config import update_config

        if isinstance(tesseract_api, str | Path):
            from tesseract_core.runtime.core import load_module_from_path

            tesseract_api_path = Path(tesseract_api).resolve(strict=True)
            if not tesseract_api_path.is_file():
                raise RuntimeError(
                    f"Tesseract API path {tesseract_api_path} is not a file."
                )

            try:
                tesseract_api = load_module_from_path(tesseract_api_path)
            except ImportError as ex:
                raise RuntimeError(
                    f"Cannot load Tesseract API from {tesseract_api_path}"
                ) from ex

        # Wrap in Path so that plain strings are accepted, as in from_image.
        if input_path is not None:
            update_config(input_path=str(Path(input_path).resolve()))
        if output_path is not None:
            update_config(output_path=str(Path(output_path).resolve()))

        update_config(output_format=output_format)

        obj = cls.__new__(cls)
        obj._spawn_config = None
        obj._serve_context = None
        obj._lastlog = None
        obj._client = LocalClient(tesseract_api)
        return obj

    def __enter__(self) -> Tesseract:
        """Enter the Tesseract context.

        This will start the Tesseract server if it is not already running.

        Raises:
            RuntimeError: If this instance is already being served.
        """
        if self._serve_context is not None:
            raise RuntimeError("Cannot serve the same Tesseract multiple times.")

        if self._client is not None:
            # Tesseract is already being served -> no-op
            return self

        self.serve()
        return self

    def __exit__(self, *args: Any) -> None:
        """Exit the Tesseract context.

        This will stop the Tesseract server if it is running.
        """
        if self._serve_context is None:
            # This can happen if __enter__ short-circuits
            return
        self.teardown()

    def server_logs(self) -> str:
        """Get the logs of the Tesseract server.

        While the Tesseract is not being served, this returns the logs that
        were captured at the last teardown (or an empty string if there are none).

        Returns:
            logs of the Tesseract server.

        Raises:
            RuntimeError: If this instance was not created via `from_image`.
        """
        if self._spawn_config is None:
            raise RuntimeError(
                "Can only retrieve logs for a Tesseract created via from_image."
            )
        if self._serve_context is None:
            return self._lastlog or ""
        return engine.logs(self._serve_context["container_name"])

    def serve(self) -> None:
        """Serve the Tesseract until it is stopped.

        Raises:
            RuntimeError: If this instance was not created via `from_image`,
                or if it is already being served.
        """
        if self._spawn_config is None:
            raise RuntimeError("Can only serve a Tesseract created via from_image.")
        if self._serve_context is not None:
            raise RuntimeError("Tesseract is already being served.")

        container_name, container = engine.serve(**self._spawn_config)
        self._serve_context = dict(
            container_name=container_name,
            port=container.host_port,
            network=self._spawn_config["network"],
            network_alias=self._spawn_config["network_alias"],
        )
        host_ip = self._spawn_config["host_ip"]
        self._lastlog = None
        self._client = HTTPClient(f"http://{host_ip}:{container.host_port}")
        # Make sure the container is torn down at interpreter exit if the
        # user never calls teardown() explicitly.
        atexit.register(self.teardown)

    def teardown(self) -> None:
        """Teardown the Tesseract.

        This will stop and remove the Tesseract container. The server logs are
        captured beforehand so that `server_logs` keeps working afterwards.

        Raises:
            RuntimeError: If the Tesseract is not being served.
        """
        if self._serve_context is None:
            raise RuntimeError("Tesseract is not being served.")
        # Grab the logs while the container still exists.
        self._lastlog = self.server_logs()
        engine.teardown(self._serve_context["container_name"])
        self._client = None
        self._serve_context = None
        atexit.unregister(self.teardown)

    def __del__(self) -> None:
        """Destructor for the Tesseract class.

        This will teardown the Tesseract if it is being served.
        """
        # getattr with a default: an alternate constructor may raise after
        # cls.__new__() but before _serve_context is assigned, in which case
        # the attribute does not exist when the object is collected.
        if getattr(self, "_serve_context", None) is not None:
            self.teardown()

    @cached_property
    @requires_client
    def openapi_schema(self) -> dict:
        """Get the OpenAPI schema of this Tesseract.

        The result is cached on the instance after the first access.

        Returns:
            dictionary with the OpenAPI Schema.
        """
        return self._client.run_tesseract("openapi_schema")

    @property
    @requires_client
    def available_endpoints(self) -> list[str]:
        """Get the list of available endpoints.

        Returns:
            a list with all available endpoints for this Tesseract.
        """
        return [endpoint.lstrip("/") for endpoint in self.openapi_schema["paths"]]

    @requires_client
    def apply(self, inputs: dict, run_id: str | None = None) -> dict:
        """Run apply endpoint.

        Args:
            inputs: a dictionary with the inputs.
            run_id: a string to identify the run. Run outputs will be located
                in a directory suffixed with this id.

        Returns:
            dictionary with the results.
        """
        payload = {"inputs": inputs}
        return self._client.run_tesseract("apply", payload, run_id)

    @requires_client
    def abstract_eval(self, abstract_inputs: dict) -> dict:
        """Run abstract eval endpoint.

        Args:
            abstract_inputs: a dictionary with the (abstract) inputs.

        Returns:
            dictionary with the results.
        """
        payload = {"inputs": abstract_inputs}
        return self._client.run_tesseract("abstract_eval", payload)

    @requires_client
    def health(self) -> dict:
        """Check the health of the Tesseract.

        Returns:
            dictionary with the health status.
        """
        return self._client.run_tesseract("health")

    @requires_client
    def jacobian(
        self,
        inputs: dict,
        jac_inputs: list[str],
        jac_outputs: list[str],
        run_id: str | None = None,
    ) -> dict:
        """Calculate the Jacobian of (some of the) outputs w.r.t. (some of the) inputs.

        Args:
            inputs: a dictionary with the inputs.
            jac_inputs: Inputs with respect to which derivatives will be calculated.
            jac_outputs: Outputs which will be differentiated.
            run_id: a string to identify the run. Run outputs will be located
                in a directory suffixed with this id.

        Returns:
            dictionary with the results.

        Raises:
            NotImplementedError: If the Tesseract does not expose a
                ``jacobian`` endpoint.
        """
        if "jacobian" not in self.available_endpoints:
            raise NotImplementedError("Jacobian not implemented for this Tesseract.")

        payload = {
            "inputs": inputs,
            "jac_inputs": jac_inputs,
            "jac_outputs": jac_outputs,
        }
        return self._client.run_tesseract("jacobian", payload, run_id)

    @requires_client
    def jacobian_vector_product(
        self,
        inputs: dict,
        jvp_inputs: list[str],
        jvp_outputs: list[str],
        tangent_vector: dict,
        run_id: str | None = None,
    ) -> dict:
        """Calculate the Jacobian Vector Product (JVP) of (some of the) outputs w.r.t. (some of the) inputs.

        Args:
            inputs: a dictionary with the inputs.
            jvp_inputs: Inputs with respect to which derivatives will be calculated.
            jvp_outputs: Outputs which will be differentiated.
            tangent_vector: Element of the tangent space to multiply with the Jacobian.
            run_id: a string to identify the run. Run outputs will be located
                in a directory suffixed with this id.

        Returns:
            dictionary with the results.

        Raises:
            NotImplementedError: If the Tesseract does not expose a
                ``jacobian_vector_product`` endpoint.
        """
        if "jacobian_vector_product" not in self.available_endpoints:
            raise NotImplementedError(
                "Jacobian Vector Product (JVP) not implemented for this Tesseract."
            )

        payload = {
            "inputs": inputs,
            "jvp_inputs": jvp_inputs,
            "jvp_outputs": jvp_outputs,
            "tangent_vector": tangent_vector,
        }
        return self._client.run_tesseract("jacobian_vector_product", payload, run_id)

    @requires_client
    def vector_jacobian_product(
        self,
        inputs: dict,
        vjp_inputs: list[str],
        vjp_outputs: list[str],
        cotangent_vector: dict,
        run_id: str | None = None,
    ) -> dict:
        """Calculate the Vector Jacobian Product (VJP) of (some of the) outputs w.r.t. (some of the) inputs.

        Args:
            inputs: a dictionary with the inputs.
            vjp_inputs: Inputs with respect to which derivatives will be calculated.
            vjp_outputs: Outputs which will be differentiated.
            cotangent_vector: Element of the cotangent space to multiply with the Jacobian.
            run_id: a string to identify the run. Run outputs will be located
                in a directory suffixed with this id.

        Returns:
            dictionary with the results.

        Raises:
            NotImplementedError: If the Tesseract does not expose a
                ``vector_jacobian_product`` endpoint.
        """
        if "vector_jacobian_product" not in self.available_endpoints:
            raise NotImplementedError(
                "Vector Jacobian Product (VJP) not implemented for this Tesseract."
            )

        payload = {
            "inputs": inputs,
            "vjp_inputs": vjp_inputs,
            "vjp_outputs": vjp_outputs,
            "cotangent_vector": cotangent_vector,
        }
        return self._client.run_tesseract("vector_jacobian_product", payload, run_id)
def _tree_map(func: Callable, tree: Any, is_leaf: Callable | None = None) -> Any:
    """Recursively apply a function to all leaves of a tree-like structure.

    Args:
        func: Function applied to every node for which ``is_leaf`` returns true.
        tree: Arbitrarily nested structure of mappings and sequences.
        is_leaf: Predicate deciding whether a node is a leaf. If ``None``,
            no node is treated as a leaf and ``func`` is never applied.

    Returns:
        A structure of the same layout. Mappings are returned as plain dicts,
        sequences (other than ``str`` / ``bytes``) are rebuilt with their own
        type, and everything else is returned unchanged.
    """
    if is_leaf is not None and is_leaf(tree):
        return func(tree)

    if isinstance(tree, Mapping):  # Dictionary-like structure
        return {key: _tree_map(func, value, is_leaf) for key, value in tree.items()}

    if isinstance(tree, Sequence) and not isinstance(
        tree, (str, bytes)
    ):  # List, tuple, etc.
        mapped = (_tree_map(func, item, is_leaf) for item in tree)
        if isinstance(tree, tuple) and hasattr(tree, "_fields"):
            # Namedtuples take their fields as positional arguments, not as
            # a single iterable.
            return type(tree)(*mapped)
        return type(tree)(mapped)

    # If nothing above matched do nothing
    return tree


def _encode_array(arr: np.ndarray, b64: bool = True) -> dict:
    """Encode an array as a dictionary with ``shape``, ``dtype`` and ``data`` keys.

    Args:
        arr: The array to encode.
        b64: If true, the buffer is the base64-encoded output of ``arr.tobytes()``;
            otherwise it is the nested list returned by ``arr.tolist()``.

    Returns:
        The encoded array.
    """
    if b64:
        data = {
            "buffer": base64.b64encode(arr.tobytes()).decode(),
            "encoding": "base64",
        }
    else:
        data = {
            "buffer": arr.tolist(),
            "encoding": "raw",
        }

    return {
        "shape": arr.shape,
        "dtype": arr.dtype.name,
        "data": data,
    }


def _decode_array(encoded_arr: dict) -> np.ndarray:
    """Decode a dictionary with ``shape``, ``dtype`` and ``data`` keys into an array.

    Args:
        encoded_arr: The encoded array. ``data["encoding"]`` equal to
            ``"base64"`` selects base64 decoding; any other value treats
            ``data["buffer"]`` as array-like input to ``np.array``.

    Returns:
        The decoded array, reshaped to ``encoded_arr["shape"]``.

    Raises:
        ValueError: If ``encoded_arr`` has no ``data`` key.
    """
    if "data" not in encoded_arr:
        raise ValueError("Encoded array does not contain 'data' key. Cannot decode.")

    if encoded_arr["data"]["encoding"] == "base64":
        data = base64.b64decode(encoded_arr["data"]["buffer"])
        arr = np.frombuffer(data, dtype=encoded_arr["dtype"])
    else:
        arr = np.array(encoded_arr["data"]["buffer"], dtype=encoded_arr["dtype"])

    arr = arr.reshape(encoded_arr["shape"])
    return arr


class HTTPClient:
    """HTTP Client for Tesseracts."""

    def __init__(self, url: str) -> None:
        """Create a client for the Tesseract served at ``url``.

        Args:
            url: URL of the Tesseract. A missing scheme defaults to ``http``.
        """
        self._url = self._sanitize_url(url)

    @staticmethod
    def _sanitize_url(url: str) -> str:
        """Normalize a URL to ``scheme://netloc/path`` without trailing slashes.

        Params, query and fragment are dropped. ``http://`` is prepended if
        the URL has no scheme.
        """
        parsed = urlparse(url)

        # Checking the scheme alone is not enough: urlparse may report "host"
        # as the scheme of a bare "host:port" string, with an empty network
        # location. An empty netloc therefore also means the scheme is missing.
        if not parsed.scheme or not parsed.netloc:
            url = f"http://{url}"
            parsed = urlparse(url)

        sanitized = urlunparse((parsed.scheme, parsed.netloc, parsed.path, "", "", ""))
        sanitized = sanitized.rstrip("/")
        return sanitized

    @property
    def url(self) -> str:
        """(Sanitized) URL to connect to."""
        return self._url

    def _request(
        self,
        endpoint: str,
        method: str = "GET",
        payload: dict | None = None,
        run_id: str | None = None,
    ) -> dict:
        """Send a request to the Tesseract and return the decoded JSON response.

        Args:
            endpoint: Endpoint path relative to the client URL.
            method: HTTP method to use.
            payload: JSON payload. Every object with a ``shape`` attribute is
                encoded via ``_encode_array`` before sending.
            run_id: If given, sent as the ``run_id`` query parameter.

        Returns:
            The JSON response. For the apply and differentiation endpoints,
            encoded arrays in the response are decoded to numpy arrays.

        Raises:
            ValidationError: If the server answers with status 422 and a
                list of error details.
            RuntimeError: For any other unsuccessful response.
        """
        url = f"{self.url}/{endpoint.lstrip('/')}"

        if payload:
            encoded_payload = _tree_map(
                _encode_array, payload, is_leaf=lambda x: hasattr(x, "shape")
            )
        else:
            encoded_payload = None

        params = {"run_id": run_id} if run_id is not None else {}
        response = requests.request(
            method=method, url=url, json=encoded_payload, params=params
        )

        if response.status_code == requests.codes.unprocessable_entity:
            # Try and raise a more helpful error if the response is a Pydantic error
            try:
                data = response.json()
            except requests.JSONDecodeError:
                # Is not a Pydantic error
                data = {}

            detail = data.get("detail") if isinstance(data, dict) else None
            # Only a list of error dictionaries can be converted; anything
            # else (e.g. a plain string) falls through to the generic
            # RuntimeError below.
            if isinstance(detail, list) and all(isinstance(e, dict) for e in detail):
                errors = []
                for e in detail:
                    ctx = e.get("ctx", {})
                    if not ctx.get("error") and e.get("msg"):
                        # Hacky, but msg contains info like "Value error, ...",
                        # which will be prepended to the message anyway by pydantic.
                        # This way, we remove whatever is before the first comma.
                        # If there is no comma, keep the message as it is.
                        _, separator, remainder = e["msg"].partition(", ")
                        ctx["error"] = remainder if separator else e["msg"]

                    error = InitErrorDetails(
                        type=e["type"],
                        loc=tuple(e["loc"]),
                        input=e.get("input"),
                        ctx=ctx,
                    )
                    errors.append(error)

                raise ValidationError.from_exception_data(
                    f"endpoint {endpoint}", line_errors=errors
                )

        if not response.ok:
            raise RuntimeError(
                f"Error {response.status_code} from Tesseract: {response.text}"
            )

        data = response.json()

        if endpoint in [
            "apply",
            "jacobian",
            "jacobian_vector_product",
            "vector_jacobian_product",
        ]:
            data = _tree_map(
                _decode_array,
                data,
                is_leaf=lambda x: type(x) is dict and "shape" in x,
            )

        return data

    def run_tesseract(
        self, endpoint: str, payload: dict | None = None, run_id: str | None = None
    ) -> dict:
        """Run a Tesseract endpoint.

        Args:
            endpoint: The endpoint to run.
            payload: The payload to send to the endpoint.
            run_id: a string to identify the run. Run outputs will be located
                in a directory suffixed with this id.

        Returns:
            The loaded JSON response from the endpoint, with decoded arrays.
        """
        if endpoint in [
            "openapi_schema",
            "health",
        ]:
            method = "GET"
        else:
            method = "POST"

        if endpoint == "openapi_schema":
            # The schema is requested from the "openapi.json" path.
            endpoint = "openapi.json"

        return self._request(endpoint, method, payload, run_id)


class LocalClient:
    """Local Client for Tesseracts."""

    def __init__(self, tesseract_api: ModuleType) -> None:
        """Collect the endpoints and the OpenAPI schema of a Tesseract API module.

        Args:
            tesseract_api: The imported Tesseract API module.
        """
        from tesseract_core.runtime.core import create_endpoints
        from tesseract_core.runtime.serve import create_rest_api

        self._endpoints = {
            func.__name__: func for func in create_endpoints(tesseract_api)
        }
        self._openapi_schema = create_rest_api(tesseract_api).openapi()

    def run_tesseract(
        self, endpoint: str, payload: dict | None = None, run_id: str | None = None
    ) -> dict:
        """Run a Tesseract endpoint.

        Args:
            endpoint: The endpoint to run.
            payload: The payload to send to the endpoint.
            run_id: a string to identify the run. Not used by this client.

        Returns:
            The result of the endpoint, validated against its return annotation
            (if any).

        Raises:
            RuntimeError: If the endpoint does not exist, or if running it
                raises an exception.
        """
        if endpoint == "openapi_schema":
            return self._openapi_schema

        if endpoint not in self._endpoints:
            raise RuntimeError(f"Endpoint {endpoint} not found in Tesseract API.")

        func = self._endpoints[endpoint]
        InputSchema = func.__annotations__.get("payload", None)
        OutputSchema = func.__annotations__.get("return", None)

        if InputSchema is not None:
            parsed_payload = InputSchema.model_validate(payload)
        else:
            parsed_payload = None

        try:
            if parsed_payload is not None:
                result = func(parsed_payload)
            else:
                result = func()
        except Exception as ex:
            # Some clients like Tesseract-JAX swallow tracebacks from re-raised exceptions, so we explicitly
            # format the traceback here to include it in the error message.
            tb = traceback.format_exc()
            raise RuntimeError(
                f"{tb}\nError running Tesseract API {endpoint}: {ex} (see above for full traceback)"
            ) from None

        if OutputSchema is not None:
            # Validate via schema, then dump to stay consistent with other clients
            if isinstance(OutputSchema, type) and issubclass(OutputSchema, BaseModel):
                result = OutputSchema.model_validate(result).model_dump()
            else:
                result = TypeAdapter(OutputSchema).validate_python(result)

        return result