# Source code for tesseract_core.runtime.schema_types

# Copyright 2025 Pasteur Labs. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from abc import ABCMeta
from enum import IntEnum
from functools import partial
from typing import (
    Annotated,
    Any,
    Optional,
    Union,
    get_args,
    get_origin,
)

import numpy as np
from pydantic import (
    AfterValidator,
    BaseModel,
    ConfigDict,
    GetCoreSchemaHandler,
    GetJsonSchemaHandler,
)
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import core_schema

from tesseract_core.runtime.array_encoding import (
    AllowedDtypes,
    decode_array,
    encode_array,
    get_array_model,
    python_to_array,
)

AnnotatedType = type(Annotated[Any, Any])
EllipsisType = type(Ellipsis)


def _ensure_valid_shapedtype(expected_shape: Any, expected_dtype: Any) -> tuple:
    if not isinstance(expected_shape, (tuple, EllipsisType)):
        raise ValueError(
            "Shape in Array[<shape>, <dtype>] must be a tuple or '...' (ellipsis)"
        )

    if isinstance(expected_shape, tuple):
        for dim in expected_shape:
            if dim is not None and not isinstance(dim, int):
                raise ValueError(
                    "Shape values in Array[<shape>, <dtype>] must be integers or None"
                )

    if is_array_annotation(expected_dtype):
        expected_dtype = expected_dtype.__metadata__[0].expected_dtype

    allowed_dtypes = get_args(AllowedDtypes)

    if expected_dtype not in allowed_dtypes and expected_dtype is not None:
        raise ValueError(
            f"Invalid dtype in Array[<shape>, <dtype>]: {expected_dtype} "
            f"(must be one of {allowed_dtypes} or a scalar Array type like, Array[(), Int32])"
        )
    return expected_shape, expected_dtype


class ShapeDType(BaseModel):
    """Data structure describing an array's shape and data type.

    Subscripting the class as ``ShapeDType[<shape>, <dtype>]`` returns an
    ``Annotated`` type with a validator that checks the ``shape`` of a
    validated instance against ``<shape>``.
    """

    shape: tuple[int, ...]
    dtype: AllowedDtypes

    # Ignore extra fields in the model, to allow encoded arrays to be passed
    model_config = ConfigDict(extra="ignore")

    def __class_getitem__(
        cls,
        key: tuple[
            Union[tuple[Optional[int], ...], EllipsisType],
            Union[AnnotatedType, str, None],
        ],
    ) -> AnnotatedType:
        """Create an annotated ShapeDType that validates against the given shape.

        Args:
            key: The pair ``(<shape>, <dtype>)``; both entries are checked by
                ``_ensure_valid_shapedtype``.

        Returns:
            ``Annotated[ShapeDType, AfterValidator(...)]``.

        Raises:
            TypeError: If ``key`` cannot be unpacked into exactly two items.
            ValueError: If the shape or dtype is invalid.
        """
        try:
            shape_spec, dtype_spec = key
        except (TypeError, ValueError):
            # Give a readable message instead of the opaque error produced by
            # star-unpacking a malformed key into the helper's arguments.
            raise TypeError(
                f"{cls.__name__}[...] expects exactly two arguments, "
                f"`{cls.__name__}[<shape>, <dtype>]`; got {key!r}"
            ) from None

        # The dtype is validated here, but the validator below compares shapes only.
        expected_shape, _ = _ensure_valid_shapedtype(shape_spec, dtype_spec)

        def validate(shapedtype: ShapeDType) -> ShapeDType:
            """Validator to check if the shape matches the expected shape.

            ``None`` entries in the expected shape match any size; an Ellipsis
            shape matches everything. The dtype is not compared.
            """
            # Values that are not ShapeDType instances are passed through
            # untouched rather than failing on a missing `.shape`.
            if expected_shape is Ellipsis or not isinstance(shapedtype, ShapeDType):
                return shapedtype

            shape = shapedtype.shape

            # TODO: replace this check with `zip(... strict=True)`
            # once we stop supporting 3.9
            if len(shape) != len(expected_shape):
                raise ValueError(f"Expected shape: {expected_shape}. Found: {shape}.")

            for actual, expected in zip(shape, expected_shape):
                if expected is not None and actual != expected:
                    raise ValueError(
                        f"Expected shape: {expected_shape}. Found: {shape}."
                    )

            return shapedtype

        return Annotated[ShapeDType, AfterValidator(validate)]

    @classmethod
    def from_array_annotation(cls, obj: AnnotatedType) -> AnnotatedType:
        """Create a ShapeDType from an array annotation.

        Reads ``expected_shape`` and ``expected_dtype`` from the first metadata
        entry of ``obj`` and returns ``cls[shape, dtype]``.
        """
        annotation = obj.__metadata__[0]
        shape = annotation.expected_shape
        dtype = annotation.expected_dtype
        return cls[shape, dtype]
class ArrayFlags(IntEnum):
    """Custom flags for array annotations."""

    DIFFERENTIABLE = 1


class ArrayAnnotationType(ABCMeta):
    """Metaclass for Array type annotation to enforce repr on created types based on class variables.

    Classes that do not define both ``expected_shape`` and ``expected_dtype``
    fall back to the default class repr.

    Example:
        >>> class MyArray(metaclass=ArrayAnnotationType):
        ...     expected_shape = (2, 3)
        ...     expected_dtype = "float32"
        >>> MyArray
        MyArray[(2, 3), 'float32']
    """

    def __repr__(cls) -> str:
        try:
            shape, dtype = cls.expected_shape, cls.expected_dtype
        except AttributeError:
            # E.g. a base class that only declares (but never sets) these
            # attributes; repr() must not raise for it.
            return super().__repr__()
        return f"{cls.__name__}[{shape!r}, {dtype!r}]"


def safe_issubclass(obj: Any, baseclass: type[object]) -> bool:
    """Check if obj is a subclass of baseclass in a way that never raises.

    (This is useful when obj is not guaranteed to be a type.)
    """
    try:
        return issubclass(obj, baseclass)
    except TypeError:
        return False


def _is_annotated(obj: Any) -> bool:
    """Check if an object is typing.Annotated or typing_extensions.Annotated."""
    return get_origin(obj) is Annotated


class PydanticArrayAnnotation(metaclass=ArrayAnnotationType):
    """Base class for Pydantic annotations for NumPy array types.

    This class provides Pydantic support for arrays with a fixed / polymorphic shape and dtype,
    with proper validation and serialization.

    See https://docs.pydantic.dev/latest/concepts/types/#handling-third-party-types

    When serializing or validating pydantic models that contain this annotation
    you can customize the array encoding to: plain json, base64, or binref.
    For more details see the docstring of 'Array'.
    """

    # These are class attributes that must be set when the class is created
    # Shape entries may be None (polymorphic dimension); Ellipsis means any shape
    expected_shape: Union[tuple[Optional[int], ...], EllipsisType]
    # None is accepted as dtype by the shape/dtype validation helper
    expected_dtype: Optional[str]
    # Zero or more flags (empty for plain arrays)
    flags: tuple[ArrayFlags, ...]

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        raise RuntimeError(f"{self.__class__.__name__} cannot be instantiated")

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: Any,
        _handler: GetCoreSchemaHandler,
    ) -> core_schema.CoreSchema:
        """This method is called by Pydantic to get the core schema for the annotated type.

        Does most of the heavy lifting for validation and serialization.

        The returned schema validates JSON input against the array model and
        then decodes it with ``decode_array``. Python input is tried the same
        way first, and otherwise handed to ``python_to_array``. Serialization
        goes through ``encode_array``.
        """
        # Create a Pydantic model for the encoded array, for easier validation
        array_schema = _handler(
            get_array_model(
                cls.expected_shape,
                cls.expected_dtype,
                [flag.name for flag in cls.flags],
            )
        )

        # Bind this annotation's shape / dtype to the generic conversion helpers
        python_to_array_ = partial(
            python_to_array,
            expected_shape=cls.expected_shape,
            expected_dtype=cls.expected_dtype,
        )
        encode_array_ = partial(
            encode_array,
            expected_shape=cls.expected_shape,
            expected_dtype=cls.expected_dtype,
        )
        decode_array_ = partial(
            decode_array,
            expected_shape=cls.expected_shape,
            expected_dtype=cls.expected_dtype,
        )

        load_from_dict_schema = core_schema.chain_schema(
            # first load / validate JSON, then decode into a NumPy array
            [
                array_schema,
                core_schema.with_info_plain_validator_function(
                    decode_array_,
                    serialization=core_schema.plain_serializer_function_ser_schema(
                        encode_array_,
                        info_arg=True,
                        return_schema=array_schema,
                    ),
                ),
            ]
        )

        return core_schema.json_or_python_schema(
            json_schema=load_from_dict_schema,
            python_schema=core_schema.union_schema(
                [
                    load_from_dict_schema,
                    # when loading from Python, we also allow any array-like object
                    core_schema.no_info_plain_validator_function(python_to_array_),
                ],
                # try the encoded form first, then the array-like fallback
                mode="left_to_right",
            ),
            serialization=core_schema.plain_serializer_function_ser_schema(
                encode_array_,
                info_arg=True,
            ),
        )

    @classmethod
    def __get_pydantic_json_schema__(
        cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> JsonSchemaValue:
        """This method is called by Pydantic to get the JSON schema for the annotated type."""
        return handler(_core_schema)
class Array:
    """Generic Pydantic type annotation for a multi-dimensional array with a fixed shape and dtype.

    Arrays will be broadcasted to the expected shape and dtype during validation,
    but dimensions must match exactly.

    Polymorphic dimensions are supported by using `None` in the shape tuple.
    To indicate a scalar, use an empty tuple.

    Arrays of any shape and rank can be represented by using `...` (ellipsis) as the shape.

    Example:
        >>> class MyModel(BaseModel):
        ...     int_array: Array[(2, 3), Int32]
        ...     float_array: Array[(None, 3), Float64]
        ...     scalar_int: Array[(), Int16]
        ...     any_shape_array: Array[..., Float32]

    You can serialize to (and validate from) different array encodings.

        >>> model = MyModel(
        ...     int_array=np.array([[1, 2, 3], [4, 5, 6]]),
        ...     float_array=np.array([[1.0, 2.0, 3.0]]),
        ...     scalar_int=np.int32(42),
        ...     any_shape_array=np.array([True, False, True]).reshape(1, 1, 3),
        ... )
        >>> model.model_dump_json(context={"array_encoding": "json"})
        >>> model.model_dump_json(context={"array_encoding": "base64"})

    or to binref:

        >>> model.model_dump_json(
        ...     context={
        ...         "array_encoding": "binref",
        ...         "base_dir": "path/to/base",
        ...         "max_file_size": 10**8,
        ...     }
        ... )

    In the 'binref' case you have to provide a base_dir to save/load binary (.bin) files.
    The .bin file(s) are written to `context['base_dir'] / f"{context['__binref_uuid']}.bin"`.
    The '__binref_uuid' is considered an internal variable and should not be modified manually!
    You can set a 'max_file_size' for the binary files. When this file size (in bytes)
    is reached, a new __binref_uuid (i.e. a new .bin) is created to append array data to.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        clsname = self.__class__.__name__
        raise RuntimeError(
            f"{clsname} cannot be instantiated directly, perhaps you meant to use `{clsname}[(shape), dtype]`?"
        )

    def __class_getitem__(
        cls,
        key: tuple[
            Union[tuple[Optional[int], ...], EllipsisType],
            Union[AnnotatedType, str, None],
        ],
    ) -> AnnotatedType:
        """Create a new type annotation based on the given shape and dtype.

        Args:
            key: The pair ``(<shape>, <dtype>)``; both entries are checked by
                ``_ensure_valid_shapedtype``.

        Returns:
            ``Annotated[np.ndarray, <new PydanticArrayAnnotation subclass>]``
            with the validated shape / dtype stored as class attributes and an
            empty ``flags`` tuple.

        Raises:
            TypeError: If ``key`` cannot be unpacked into exactly two items.
            ValueError: If the shape or dtype is invalid.
        """
        try:
            shape_spec, dtype_spec = key
        except (TypeError, ValueError):
            # Give a readable message instead of the opaque error produced by
            # star-unpacking a malformed key into the helper's arguments.
            raise TypeError(
                f"{cls.__name__}[...] expects exactly two arguments, "
                f"`{cls.__name__}[<shape>, <dtype>]`; got {key!r}"
            ) from None

        expected_shape, expected_dtype = _ensure_valid_shapedtype(
            shape_spec, dtype_spec
        )

        classvars = {
            "expected_shape": expected_shape,
            "expected_dtype": expected_dtype,
            "flags": (),
            "__module__": cls.__module__,
        }
        # Every subscription creates its own annotation class carrying the
        # shape / dtype as class attributes.
        model = type(cls.__name__, (PydanticArrayAnnotation,), classvars)
        return Annotated[np.ndarray, model]
def is_array_annotation(obj: Any) -> bool:
    """Tell whether ``obj`` is a Pydantic array type annotation.

    This holds when ``obj`` is an ``Annotated`` type whose first metadata entry
    is a subclass of ``PydanticArrayAnnotation``.
    """
    if not _is_annotated(obj):
        return False

    first_metadata = obj.__metadata__[0]
    return safe_issubclass(first_metadata, PydanticArrayAnnotation)
class Differentiable:
    """Type annotation for a differentiable array.

    Example:
        >>> class MyModel(BaseModel):
        ...     array: Differentiable[Array[(None, 3), Float64]]
    """

    def __class_getitem__(cls, key: Any) -> AnnotatedType:
        """Mark wrapped array type as differentiable.

        Args:
            key: An array annotation, e.g. ``Array[(None, 3), Float64]``.

        Returns:
            A new ``Annotated`` type with the same origin as ``key`` whose
            annotation class carries ``ArrayFlags.DIFFERENTIABLE``.

        Raises:
            ValueError: If ``key`` is not an array annotation.
        """
        if not is_array_annotation(key):
            raise ValueError("Differentiable can only be applied to Array types")

        arr = key.__metadata__[0]

        # Keep the flag unique so that wrapping an already differentiable array
        # again does not list DIFFERENTIABLE twice.
        if ArrayFlags.DIFFERENTIABLE in arr.flags:
            flags = tuple(arr.flags)
        else:
            flags = (*arr.flags, ArrayFlags.DIFFERENTIABLE)

        # Create a new array type with the DIFFERENTIABLE flag, to not modify the original type in-place
        newarr = type(arr.__name__, (arr,), {"flags": flags})
        return Annotated[key.__origin__, newarr]
def is_differentiable(obj: Any) -> bool:
    """Tell whether ``obj`` is an array annotation flagged as differentiable.

    Returns ``False`` for anything that is not an array annotation.
    """
    if not is_array_annotation(obj):
        return False

    annotation = obj.__metadata__[0]
    return ArrayFlags.DIFFERENTIABLE in annotation.flags


# Export concrete scalar types
Float16 = Array[(), "float16"]
Float32 = Array[(), "float32"]
Float64 = Array[(), "float64"]
Int8 = Array[(), "int8"]
Int16 = Array[(), "int16"]
Int32 = Array[(), "int32"]
Int64 = Array[(), "int64"]
Bool = Array[(), "bool"]
UInt8 = Array[(), "uint8"]
UInt16 = Array[(), "uint16"]
UInt32 = Array[(), "uint32"]
UInt64 = Array[(), "uint64"]
Complex64 = Array[(), "complex64"]
Complex128 = Array[(), "complex128"]