# Copyright 2025 Pasteur Labs. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""Metrics, Parameters, and Artifacts (MPA) library for Tesseract Core."""
import csv
import json
import os
import shutil
import sys
from abc import ABC, abstractmethod
from collections.abc import Generator
from contextlib import ExitStack, contextmanager
from contextvars import ContextVar
from datetime import datetime
from io import UnsupportedOperation
from pathlib import Path
from typing import Any, Optional, Union
import requests
from tesseract_core.runtime.config import get_config
from tesseract_core.runtime.logs import LogPipe
class BaseBackend(ABC):
    """Abstract interface for MPA (metrics/parameters/artifacts) backends.

    On construction, ensures a ``logs`` directory exists under ``base_dir``
    (or under the configured output path when ``base_dir`` is None); concrete
    backends write their output below ``self.log_dir``.
    """

    def __init__(self, base_dir: Optional[str] = None) -> None:
        # Fall back to the configured output path when no directory is given.
        root = get_config().output_path if base_dir is None else base_dir
        self.log_dir = Path(root) / "logs"
        self.log_dir.mkdir(parents=True, exist_ok=True)

    @abstractmethod
    def log_parameter(self, key: str, value: Any) -> None:
        """Log a parameter."""

    @abstractmethod
    def log_metric(self, key: str, value: float, step: Optional[int] = None) -> None:
        """Log a metric."""

    @abstractmethod
    def log_artifact(self, local_path: str) -> None:
        """Log an artifact."""

    @abstractmethod
    def start_run(self) -> None:
        """Start a new run."""

    @abstractmethod
    def end_run(self) -> None:
        """End the current run."""
class FileBackend(BaseBackend):
    """MPA backend that writes to local files."""

    def __init__(self, base_dir: Optional[str] = None) -> None:
        super().__init__(base_dir)
        # Target locations under the shared log directory.
        self.params_file = self.log_dir / "parameters.json"
        self.metrics_file = self.log_dir / "metrics.csv"
        self.artifacts_dir = self.log_dir / "artifacts"
        self.artifacts_dir.mkdir(exist_ok=True)
        # In-memory mirrors of everything logged so far.
        self.parameters = {}
        self.metrics = []
        # Start the metrics CSV fresh with a header row.
        with open(self.metrics_file, "w", newline="") as fh:
            csv.writer(fh).writerow(["timestamp", "key", "value", "step"])

    def log_parameter(self, key: str, value: Any) -> None:
        """Log a parameter to JSON file."""
        self.parameters[key] = value
        # Rewrite the whole dict so the file always reflects current state.
        with open(self.params_file, "w") as fh:
            json.dump(self.parameters, fh, indent=2, default=str)

    def log_metric(self, key: str, value: float, step: Optional[int] = None) -> None:
        """Log a metric to CSV file."""
        now = datetime.now().isoformat()
        if step is None:
            # Auto-increment: count how many values this key already has.
            step = sum(1 for entry in self.metrics if entry["key"] == key)
        self.metrics.append(
            {"timestamp": now, "key": key, "value": value, "step": step}
        )
        with open(self.metrics_file, "a", newline="") as fh:
            csv.writer(fh).writerow([now, key, value, step])

    def log_artifact(self, local_path: str) -> None:
        """Copy artifact to the artifacts directory."""
        src = Path(local_path)
        if not src.exists():
            raise FileNotFoundError(f"Artifact file not found: {local_path}")
        shutil.copy2(src, self.artifacts_dir / src.name)

    def start_run(self) -> None:
        """Start a new run. File backend doesn't need special start logic."""

    def end_run(self) -> None:
        """End the current run. File backend doesn't need special end logic."""
class MLflowBackend(BaseBackend):
    """MPA backend that writes to an MLflow tracking server."""

    def __init__(self, base_dir: Optional[str] = None) -> None:
        super().__init__(base_dir)
        os.environ["GIT_PYTHON_REFRESH"] = (
            "quiet"  # Suppress potential MLflow git warnings
        )
        try:
            import mlflow
        except ImportError as exc:
            raise ImportError(
                "MLflow is required for MLflowBackend but is not installed"
            ) from exc
        self._ensure_mlflow_reachable()
        self.mlflow = mlflow
        config = get_config()
        tracking_uri = config.mlflow_tracking_uri
        if not tracking_uri.startswith(("http://", "https://")):
            # If it's a file URI, strip the scheme to get a local path.
            # removeprefix only touches a leading "file://" (str.replace
            # would drop the substring anywhere in the URI).
            tracking_uri = tracking_uri.removeprefix("file://")
            # Relative paths are resolved against the base output path;
            # normalize to str so set_tracking_uri gets one type on both branches.
            if not Path(tracking_uri).is_absolute():
                tracking_uri = str((Path(config.output_path) / tracking_uri).resolve())
        mlflow.set_tracking_uri(tracking_uri)

    def _ensure_mlflow_reachable(self) -> None:
        """Check if the MLflow tracking server is reachable.

        Raises:
            RuntimeError: If an HTTP(S) tracking URI is configured but not
                responding. Local/file URIs are not checked.
        """
        config = get_config()
        mlflow_tracking_uri = config.mlflow_tracking_uri
        if mlflow_tracking_uri.startswith(("http://", "https://")):
            try:
                response = requests.get(mlflow_tracking_uri, timeout=5)
                response.raise_for_status()
            except requests.RequestException as e:
                raise RuntimeError(
                    f"Failed to connect to MLflow tracking server at {mlflow_tracking_uri}. "
                    "Please make sure an MLflow server is running and TESSERACT_MLFLOW_TRACKING_URI is set correctly, "
                    "or switch to file-based logging by setting TESSERACT_MLFLOW_TRACKING_URI to an empty string."
                ) from e

    def log_parameter(self, key: str, value: Any) -> None:
        """Log a parameter to MLflow."""
        self.mlflow.log_param(key, value)

    def log_metric(self, key: str, value: float, step: Optional[int] = None) -> None:
        """Log a metric to MLflow."""
        self.mlflow.log_metric(key, value, step=step)

    def log_artifact(self, local_path: str) -> None:
        """Log an artifact to MLflow."""
        self.mlflow.log_artifact(local_path)

    def start_run(self) -> None:
        """Start a new MLflow run."""
        self.mlflow.start_run()

    def end_run(self) -> None:
        """End the current MLflow run."""
        self.mlflow.end_run()
def _create_backend(base_dir: Optional[str]) -> BaseBackend:
    """Create the appropriate backend based on environment."""
    # A non-empty MLflow tracking URI in the config selects the MLflow
    # backend; otherwise fall back to plain file-based logging.
    backend_cls = MLflowBackend if get_config().mlflow_tracking_uri else FileBackend
    return backend_cls(base_dir)
# Context variable for the current backend instance
# (context-local so nested or concurrent runs each see their own backend;
# unset until start_run() enters, reset when it exits).
_current_backend: ContextVar[BaseBackend] = ContextVar("current_backend")
def _get_current_backend() -> BaseBackend:
    """Get the current backend instance from context variable."""
    try:
        backend = _current_backend.get()
    except LookupError as exc:
        # No backend registered in this context: the caller is outside a run.
        raise RuntimeError(
            "No active MPA run. Use 'with mpa.start_run():' to start a run."
        ) from exc
    return backend
# Public API functions that work with the current context
def log_parameter(key: str, value: Any) -> None:
    """Log a parameter to the current run context.

    Raises:
        RuntimeError: If called outside an active ``start_run`` context.
    """
    _get_current_backend().log_parameter(key, value)
def log_metric(key: str, value: float, step: Optional[int] = None) -> None:
    """Log a metric to the current run context.

    Raises:
        RuntimeError: If called outside an active ``start_run`` context.
    """
    _get_current_backend().log_metric(key, value, step)
def log_artifact(local_path: str) -> None:
    """Log an artifact to the current run context.

    Raises:
        RuntimeError: If called outside an active ``start_run`` context.
    """
    _get_current_backend().log_artifact(local_path)
@contextmanager
def redirect_stdio(logfile: Union[str, Path]) -> Generator[None, None, None]:
    """Context manager for redirecting stdout and stderr to a custom pipe.

    Writes messages to both the original stderr and the given logfile.

    Args:
        logfile: Path of the file that receives a copy of all output.
    """
    from tesseract_core.runtime.core import redirect_fd

    try:
        # Check if a file descriptor is available
        sys.stdout.fileno()
        sys.stderr.fileno()
    except UnsupportedOperation:
        # Don't redirect if stdout/stderr are not file descriptors
        # (This likely means that streams are already redirected)
        yield
        return

    with ExitStack() as stack:
        f = stack.enter_context(open(logfile, "w"))
        orig_stderr = sys.stderr

        # Use `print` instead of `.write` so we get appropriate newlines and
        # flush behavior. Named functions instead of assigned lambdas (PEP 8).
        def write_to_stderr(msg: str) -> None:
            print(msg, file=orig_stderr, flush=True)

        def write_to_file(msg: str) -> None:
            print(msg, file=f, flush=True)

        pipe_fd = stack.enter_context(LogPipe(write_to_stderr, write_to_file))
        # Redirect file descriptors at OS level
        stack.enter_context(redirect_fd(sys.stdout, pipe_fd))
        stack.enter_context(redirect_fd(sys.stderr, pipe_fd))
        yield
@contextmanager
def start_run(base_dir: Optional[str] = None) -> Generator[None, None, None]:
    """Context manager for starting and ending a run.

    Creates a backend (MLflow or file-based, depending on configuration),
    registers it as the current context backend, and mirrors stdout/stderr
    to ``tesseract.log`` in the backend's log directory for the run.

    Args:
        base_dir: Base output directory; defaults to the configured output path.
    """
    backend = _create_backend(base_dir)
    token = _current_backend.set(backend)
    try:
        backend.start_run()
        logfile = backend.log_dir / "tesseract.log"
        with redirect_stdio(logfile):
            yield
    finally:
        # Reset the context variable even if end_run() raises, so a failed
        # teardown cannot leak this backend into subsequent runs.
        try:
            backend.end_run()
        finally:
            _current_backend.reset(token)