# Source code for tesseract_core.sdk.api_parse

# Copyright 2025 Pasteur Labs. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import ast
from pathlib import Path, PurePosixPath, PureWindowsPath
from typing import Annotated, Literal, NamedTuple, Union

import yaml
from pydantic import (
    AfterValidator,
    BaseModel,
    BeforeValidator,
    ConfigDict,
    Field,
    Strict,
)
from pydantic import ValidationError as PydanticValidationError


class _ApiObject(NamedTuple):
    """Specification of one top-level object expected in ``tesseract_api.py``.

    ``validate_tesseract_api`` checks each spec against the parsed module AST.
    """

    # Name under which the object must be defined at module top level.
    name: str
    # AST node type the definition must be (e.g. ``ast.FunctionDef`` or ``ast.ClassDef``).
    expected_type: type
    # Required number of arguments for a function; None skips the signature check.
    num_args: int | None = None
    # Required argument names, in positional order; consulted only when num_args is set.
    arg_names: tuple[str, ...] | None = None
    # When True, a missing definition is not an error.
    optional: bool = False


# Ordinal words for argument positions 0..7, used when reporting misnamed arguments.
ORDINALS = "first second third fourth fifth sixth seventh eighth".split()

# Top-level objects that tesseract_api.py must define (or, for optional
# entries, may define), together with the exact signature required of each
# endpoint function.
EXPECTED_OBJECTS = (
    # Mandatory endpoint and schemas.
    _ApiObject(
        name="apply",
        expected_type=ast.FunctionDef,
        num_args=1,
        arg_names=("inputs",),
    ),
    _ApiObject(name="InputSchema", expected_type=ast.ClassDef),
    _ApiObject(name="OutputSchema", expected_type=ast.ClassDef),
    # Optional endpoints.
    _ApiObject(
        name="jacobian",
        expected_type=ast.FunctionDef,
        num_args=3,
        arg_names=("inputs", "jac_inputs", "jac_outputs"),
        optional=True,
    ),
    _ApiObject(
        name="jacobian_vector_product",
        expected_type=ast.FunctionDef,
        num_args=4,
        arg_names=("inputs", "jvp_inputs", "jvp_outputs", "tangent_vector"),
        optional=True,
    ),
    _ApiObject(
        name="vector_jacobian_product",
        expected_type=ast.FunctionDef,
        num_args=4,
        arg_names=("inputs", "vjp_inputs", "vjp_outputs", "cotangent_vector"),
        optional=True,
    ),
    _ApiObject(
        name="abstract_eval",
        expected_type=ast.FunctionDef,
        num_args=1,
        arg_names=("abstract_inputs",),
        optional=True,
    ),
)


def assert_relative_path(value: str) -> str:
    """Assert that a string encodes a relative path.

    The check does not depend on the host platform: a value is rejected if it
    is absolute, rooted or drive-qualified under either POSIX or Windows path
    rules. ``Path(value).is_absolute()`` alone would let ``/etc/passwd`` pass
    on Windows and ``C:\\data`` pass on Linux, so the same config file would
    validate differently depending on where it is built.

    Args:
        value: Path string to check.

    Returns:
        ``value`` unchanged, so the function can be used as a pydantic
        ``AfterValidator``.

    Raises:
        ValueError: If ``value`` is not a relative path.
    """
    windows_path = PureWindowsPath(value)
    is_anchored = (
        PurePosixPath(value).is_absolute()
        # Catches "C:\\x", "\\\\server\\share\\x" and rooted "\\x" / "/x".
        or bool(windows_path.root)
        # Catches drive-relative paths such as "C:x".
        or bool(windows_path.drive)
    )
    if is_anchored:
        raise ValueError(f"value must be a relative path (got {value})")
    return value


# ``str`` that is additionally passed through ``assert_relative_path`` after
# normal string validation, so non-relative paths are rejected with ValueError.
RelativePath = Annotated[str, AfterValidator(assert_relative_path)]
# ``str`` validated in pydantic strict mode, i.e. values of other types
# (such as numbers parsed from YAML) are not coerced to strings.
StrictStr = Annotated[str, Strict()]


class PipRequirements(BaseModel):
    """Python environment configuration for the pip provider.

    Chosen when ``provider`` is ``"python-pip"``; unknown keys are rejected.
    """

    model_config = ConfigDict(extra="forbid")

    provider: Literal["python-pip"]

    # Underscore-prefixed attributes are pydantic private attributes, not
    # user-settable config fields: the requirements file name and build
    # script associated with this provider.
    _build_script: Literal["build_pip_venv.sh"] = "build_pip_venv.sh"
    _filename: Literal["tesseract_requirements.txt"] = "tesseract_requirements.txt"


class CondaRequirements(BaseModel):
    """Python environment configuration for the conda provider.

    Chosen when ``provider`` is ``"conda"``; unknown keys are rejected.
    """

    model_config = ConfigDict(extra="forbid")

    provider: Literal["conda"]

    # Underscore-prefixed attributes are pydantic private attributes, not
    # user-settable config fields: the environment file name and build
    # script associated with this provider.
    _build_script: Literal["build_conda_venv.sh"] = "build_conda_venv.sh"
    _filename: Literal["tesseract_environment.yaml"] = "tesseract_environment.yaml"


# Either supported Python environment provider. Each member pins a distinct
# ``provider`` literal and forbids extra keys, so at most one member can
# validate a given input.
PythonRequirements = Union[PipRequirements, CondaRequirements]


class TesseractBuildConfig(BaseModel):
    """Configuration options for building a Tesseract.

    All fields have defaults, so ``TesseractBuildConfig()`` is valid; unknown
    keys are rejected (``extra="forbid"``).
    """

    base_image: StrictStr = Field(
        "debian:bookworm-slim",
        description="Base Docker image for the build. Must be Debian-based.",
    )
    target_platform: StrictStr = Field(
        "native",
        description=(
            "Target platform for the Docker build. Must be a valid Docker platform, "
            "or 'native' to build for the host platform. "
            "In general, images built for one platform will not run on another."
        ),
    )
    extra_packages: tuple[StrictStr, ...] = Field(
        (), description="Extra packages to install during build via apt-get."
    )
    package_data: tuple[tuple[RelativePath, StrictStr], ...] | None = Field(
        (),
        description=(
            "Additional files to copy into the Docker image, in the format ``(source, destination)``. "
            "Source paths are relative to the Tesseract directory."
        ),
    )
    custom_build_steps: tuple[StrictStr, ...] | None = Field(
        (),
        description=(
            "Custom steps to run during ``docker build`` (after everything else is installed). "
            "Example: ``[\"RUN echo 'Hello, world!'\"]``"
        ),
    )
    requirements: PythonRequirements = Field(
        PipRequirements(provider="python-pip"),
        description=(
            "How the Tesseract's Python environment is built "
            "(``provider`` is either 'python-pip' or 'conda')."
        ),
    )

    model_config = ConfigDict(extra="forbid")
# Allow None to be passed as a valid value for build_config, for example in
# YAML that comments out all options (a key with no value loads as None).
OptionalBuildConfig = Annotated[
    TesseractBuildConfig,
    BeforeValidator(lambda v: TesseractBuildConfig() if v is None else v),
]
class TesseractConfig(BaseModel):
    """Configuration options for Tesseracts.

    Defines valid options in ``tesseract_config.yaml``. Only ``name`` is
    required; unknown keys are rejected (``extra="forbid"``).
    """

    name: StrictStr = Field(..., description="Name of the Tesseract.")
    version: StrictStr = Field("0+unknown", description="Version of the Tesseract.")
    description: StrictStr = Field(
        "",
        description="Free-text description of what the Tesseract does.",
    )
    # OptionalBuildConfig also maps an explicit None to the defaults.
    build_config: OptionalBuildConfig = Field(
        default_factory=TesseractBuildConfig,
        description="Configuration options for building the Tesseract.",
    )

    model_config = ConfigDict(extra="forbid")
class ValidationError(Exception):
    """Raised when inputs needed to build a tesseract are invalid."""

    pass


def _get_func_argnames(func: ast.FunctionDef) -> tuple[str, ...]:
    """Get the names of the regular (positional-or-keyword) arguments of a function node.

    See:
        https://docs.python.org/3/library/ast.html#ast.FunctionDef
        https://docs.python.org/3/library/ast.html#ast.arguments

    Raises:
        ValidationError: If the function declares keyword-only or
            positional-only arguments.
    """
    func_args = func.args
    if func_args.kwonlyargs:
        raise ValidationError(
            f"Function {func.name} must not have keyword-only arguments"
        )
    if func_args.posonlyargs:
        raise ValidationError(
            f"Function {func.name} must not have positional-only arguments"
        )
    # ``*args`` / ``**kwargs`` are stored in ``vararg`` / ``kwarg`` and are
    # intentionally not part of the returned names.
    return tuple(arg.arg for arg in func_args.args)


def _base_class_name(base: ast.expr) -> str | None:
    """Return the trailing identifier of a base-class expression, if any.

    ``BaseModel`` and ``pydantic.BaseModel`` both yield ``"BaseModel"``; other
    expressions (e.g. subscripts such as ``Generic[T]``) yield None.
    """
    if isinstance(base, ast.Name):
        return base.id
    if isinstance(base, ast.Attribute):
        return base.attr
    return None


def validate_tesseract_api(src_dir: Path) -> None:
    """Check that given folder contains a Tesseract API that satisfies our constraints.

    This function does not return anything, but it raises `ValidationError` in case
    something goes wrong. In particular, we are checking that:

    * The mandatory endpoints needed for a tesseract are actually implemented
    * The implemented functions have the correct signature
    * Both InputSchema and OutputSchema are `pydantic.BaseModel`s.

    Args:
        src_dir (Path): Path to the directory containing tesseract_api.py and
            tesseract_config.yaml.

    Raises:
        ValidationError: If either file is missing, the config is invalid or not
            valid YAML, ``tesseract_api.py`` cannot be parsed, or any of the checks
            above fails.
    """
    tesseract_api_location = src_dir / "tesseract_api.py"
    config_location = src_dir / "tesseract_config.yaml"

    if not tesseract_api_location.exists():
        raise ValidationError(f"No file found at {tesseract_api_location}")

    if not config_location.exists():
        raise ValidationError(f"No file found at {config_location}")

    # Validate config. Schema errors are already converted to ValidationError by
    # get_config; YAML syntax errors are wrapped here so callers only ever need
    # to handle ValidationError.
    try:
        get_config(src_dir)
    except yaml.YAMLError as err:
        raise ValidationError(f"Invalid YAML in {config_location}: {err}") from err

    # Parse Tesseract API. Reading raw bytes lets ast.parse apply Python's own
    # source-encoding rules (UTF-8 by default, PEP 263 coding cookies) instead of
    # the platform's locale encoding.
    tesseract_api_code = tesseract_api_location.read_bytes()
    try:
        tesseract_api = ast.parse(
            tesseract_api_code, filename=str(tesseract_api_location)
        )
    except (SyntaxError, ValueError) as err:
        # Older Python versions raise ValueError (not SyntaxError) for null bytes.
        raise ValidationError(
            f"Syntax error in {tesseract_api_location}: {err}"
        ) from err

    # Map each named top-level statement (def, class, ...) to its node; later
    # definitions shadow earlier ones, as they would at import time.
    toplevel_objects = {
        node.name: node for node in tesseract_api.body if hasattr(node, "name")
    }

    # Check that expected attributes are defined with the right kind and signature.
    for obj in EXPECTED_OBJECTS:
        node = toplevel_objects.get(obj.name)
        if node is None:
            if obj.optional:
                continue
            raise ValidationError(f"{obj.name} not defined in {tesseract_api_location}")

        if not isinstance(node, obj.expected_type):
            raise ValidationError(
                f"{obj.name} is not a {obj.expected_type.__name__} in {tesseract_api_location}"
            )

        if obj.num_args is None:
            continue

        func_argnames = _get_func_argnames(node)
        func_argnums = len(func_argnames)
        if func_argnums != obj.num_args:
            raise ValidationError(
                f"{obj.name} must have {obj.num_args} arguments: {', '.join(obj.arg_names)}.\n"
                f"However, {tesseract_api_location} specifies {func_argnums} "
                f"arguments: {', '.join(func_argnames)}."
            )

        # Counts match, so compare names position by position and report all
        # mismatches at once.
        msgs = [
            f"The {ORDINALS[i]} argument (argument {i}) of {obj.name} must be named {expected}, "
            f"but {tesseract_api_location} has named it {actual}."
            for i, (actual, expected) in enumerate(zip(func_argnames, obj.arg_names))
            if actual != expected
        ]
        if msgs:
            raise ValidationError("\n".join(msgs))

    # Check InputSchema and OutputSchema are pydantic BaseModels. Both are
    # guaranteed to exist as ClassDef nodes by the loop above. Accept
    # ``BaseModel`` or ``<module>.BaseModel`` in any base position.
    for schema in ("InputSchema", "OutputSchema"):
        class_node = toplevel_objects[schema]
        base_names = [_base_class_name(base) for base in class_node.bases]
        if "BaseModel" not in base_names:
            got = ", ".join(ast.unparse(base) for base in class_node.bases) or None
            raise ValidationError(
                f"{schema} must be a subclass of pydantic.BaseModel (got: {got})"
            )


def get_config(src_dir: Path) -> TesseractConfig:
    """Get configuration options from a tesseract_config.yaml file.

    Args:
        src_dir: Directory containing ``tesseract_config.yaml``.

    Returns:
        The validated configuration.

    Raises:
        FileNotFoundError: If the config file does not exist.
        ValidationError: If the file content does not match ``TesseractConfig``.
        yaml.YAMLError: If the file is not valid YAML.
    """
    config_file = src_dir / "tesseract_config.yaml"
    if not config_file.exists():
        raise FileNotFoundError(f"No file found at {config_file}")

    # Binary mode lets PyYAML detect the encoding (UTF-8/UTF-16) itself rather
    # than depending on the platform's default text encoding.
    with open(config_file, "rb") as f:
        config = yaml.safe_load(f)

    # An empty document loads as None; treat it as an empty mapping so the user
    # gets a "field required" error instead of a TypeError.
    if config is None:
        config = {}

    # model_validate (unlike ``TesseractConfig(**config)``) reports non-mapping
    # documents as pydantic errors rather than raising TypeError.
    try:
        return TesseractConfig.model_validate(config)
    except PydanticValidationError as err:
        raise ValidationError(
            f"Invalid configuration in {config_file}: {err}"
        ) from err


def get_non_base_fields_in_tesseract_config() -> list[tuple[str, type]]:
    """Get the fields of TesseractConfig whose annotation is not a base type.

    Base types are ``str``, ``int``, ``float``, ``bool`` and ``bytes``.

    Returns:
        ``(field_name, annotation)`` pairs, in the iteration order of
        ``TesseractConfig.model_fields``.
    """
    base_fields = (str, int, float, bool, bytes)
    return [
        (field_name, field_info.annotation)
        for field_name, field_info in TesseractConfig.model_fields.items()
        if field_info.annotation not in base_fields
    ]