# Source code for statistics.nnPreprocessing

#!/usr/bin/env python3
"""
.. module:: preprocessing_nnAdapter
   :synopsis: Preprocessing and inverse-preprocessing functions for the
   NNAdapter. Handles feature standardisation and nLL postprocessing,
   including uncertainty propagation through arbitrary transformation chains.

"""

__all__ = [
    "preprocess_features",
    "postprocess_nLLs",
    "postprocess_nLLs_errors",
]

import numpy as np
from typing import Optional


# ---------------------------------------------------------------------------
# Elementary transforms (forward and inverse)
# ---------------------------------------------------------------------------

def _log_with_negatives(x: np.ndarray) -> np.ndarray:
    """Sign-preserving log1p: returns ``sign(x) * log1p(|x|)``.

    Defined for negative inputs as well; undone by
    ``_undo_log_with_negatives``.

    :param x: values to transform
    :returns: transformed values
    """
    magnitude = np.log1p(np.abs(x))
    return np.sign(x) * magnitude


def _undo_log_with_negatives(x: np.ndarray) -> np.ndarray:
    """Invert ``_log_with_negatives``: returns ``sign(x) * expm1(|x|)``.

    :param x: values produced by the signed log1p transform
    :returns: values on the original scale
    """
    magnitude = np.expm1(np.abs(x))
    return np.sign(x) * magnitude


def _standardize(
    x: np.ndarray,
    mean: np.ndarray,
    std: np.ndarray,
) -> np.ndarray:
    """Centre ``x`` on ``mean`` and scale it by ``std``.

    Both statistics are supplied by the caller; nothing is estimated here.

    :param x: values to standardise
    :param mean: pre-computed mean to subtract
    :param std: pre-computed standard deviation to divide by
    :returns: ``(x - mean) / std``
    """
    centred = x - mean
    return centred / std


def _undo_standardize(
    x: np.ndarray,
    mean: np.ndarray,
    std: np.ndarray,
) -> np.ndarray:
    """Map standardised values back to the original scale.

    :param x: standardised values
    :param mean: mean that was subtracted during standardisation
    :param std: standard deviation that was divided by during standardisation
    :returns: ``x * std + mean``
    """
    rescaled = x * std
    return rescaled + mean


# ---------------------------------------------------------------------------
# Feature preprocessing (forward, used at inference time)
# ---------------------------------------------------------------------------

def _get_fn(fn_str: str):
    """Look up the forward transform registered under ``fn_str``.

    :param fn_str: one of "log", "exp", "sqrt", "inverse", "log_w_negatives"
    :returns: a callable applying the named transform
    :raises ValueError: if ``fn_str`` names no known transform
    """
    if fn_str == "log":
        return np.log
    if fn_str == "exp":
        return np.exp
    if fn_str == "sqrt":
        return np.sqrt
    if fn_str == "inverse":
        return lambda x: 1.0 / x
    if fn_str == "log_w_negatives":
        return _log_with_negatives
    raise ValueError(f"Unknown transform '{fn_str}'.")


def preprocess_features(
    features_raw: np.ndarray,
    trafos: Optional[dict] = None,
    mean: Optional[np.ndarray] = None,
    std: Optional[np.ndarray] = None,
) -> tuple[list, Optional[np.ndarray], Optional[np.ndarray]]:
    """Apply feature preprocessing using saved mean/std from training.

    Every feature set named in ``trafos`` is built from a fresh copy of
    ``features_raw`` by applying its transforms in order, but only the set
    called ``"fvs_standardized"`` is returned.  When no transforms are given
    the raw features are passed through unchanged.

    :param features_raw: raw input features; each element yielded by iterating
        the selected array is converted with ``float``.
        NOTE(review): this conversion suits a 1-D feature vector; confirm
        against callers whether batched (batch, n_features) input is expected.
    :param trafos: dict mapping set-name to list of transform strings,
        e.g. ``{"fvs_standardized": ["log_w_negatives", "standardization"]}``
    :param mean: per-feature mean saved at training time
        (required when "standardization" is in trafos)
    :param std: per-feature std saved at training time
        (required when "standardization" is in trafos)
    :returns: (scaled_features, mean, std), where scaled_features is a list of
        Python floats and mean/std are handed back untouched
    :raises AssertionError: if ``features_raw`` contains non-finite values, or
        if "standardization" is requested without ``mean`` and ``std``
    :raises KeyError: if ``trafos`` is non-empty but defines no
        "fvs_standardized" set
    :raises ValueError: if a transform name is unknown
    """
    assert np.isfinite(features_raw).all(), "Non-finite values in features_raw"

    feature_sets = {}
    for set_name, fn_list in (trafos or {}).items():
        # Each set starts from the raw features, not from a previous set.
        transformed = features_raw.copy()
        for fn_str in fn_list:
            if fn_str == "standardization":
                assert mean is not None and std is not None, \
                    "mean and std must be provided for standardization at inference time"
                transformed = _standardize(transformed, mean, std)
            else:
                transformed = _get_fn(fn_str)(transformed)
        feature_sets[set_name] = transformed

    if not trafos:
        # Nothing to transform: fall back to the raw features instead of
        # failing on the missing "fvs_standardized" entry.
        selected = features_raw
    elif "fvs_standardized" in feature_sets:
        selected = feature_sets["fvs_standardized"]
    else:
        raise KeyError(
            "trafos must define a 'fvs_standardized' feature set; "
            f"got {sorted(feature_sets)}"
        )

    scaled = list(map(float, selected))
    return scaled, mean, std
# --------------------------------------------------------------------------- # nLL postprocessing (inverse transforms) # --------------------------------------------------------------------------- def _get_inv_fn(fn_str: str): """Return the inverse of a named scalar transform.""" match fn_str: case "log": return np.exp case "exp": return np.log case "sqrt": return lambda x: x ** 2 case "inverse": return lambda x: 1.0 / x case "log_w_negatives": return _undo_log_with_negatives case _: raise ValueError(f"No known inverse for transform '{fn_str}'.")
def postprocess_nLLs(
    nLLs: np.ndarray,
    mean: np.ndarray,
    std: np.ndarray,
    trafos: Optional[list] = None,
) -> np.ndarray:
    """Undo nLL preprocessing in reverse order.

    :param nLLs: preprocessed nLL deltas output by the NN, shape (4,)
    :param mean: per-output mean saved at training time
    :param std: per-output std saved at training time
    :param trafos: list of transform strings applied during training,
        e.g. ``["log_w_negatives", "standardization"]``; if empty or None the
        input is returned as is
    :returns: unpreprocessed nLL deltas, same shape as nLLs
    :raises AssertionError: if the result contains non-finite values
    :raises ValueError: if a transform name has no known inverse
    """
    if trafos:
        # The last transform applied during training is the first one undone.
        for fn_str in reversed(trafos):
            if fn_str == "standardization":
                nLLs = _undo_standardize(nLLs, mean, std)
            else:
                nLLs = _get_inv_fn(fn_str)(nLLs)

    assert np.isfinite(nLLs).all(), "Non-finite values after postprocess_nLLs"
    return nLLs
def postprocess_nLLs_errors(
    errors: np.ndarray,
    nLLs_prepd: np.ndarray,
    mean: np.ndarray,
    std: np.ndarray,
    trafos: Optional[list] = None,
    eps: float = 1e-5,
) -> np.ndarray:
    """Propagate heteroskedastic errors through the inverse nLL preprocessing.

    Uses central-difference numerical differentiation to compute
    ``|d(postprocess)/d(nLL)| * sigma`` for each output independently.

    :param errors: NN-predicted uncertainties on the preprocessed deltas,
        shape (4,), these are the errors on nLLs_prepd
    :param nLLs_prepd: the central (mean) preprocessed nLL deltas, shape (4,)
    :param mean: per-output mean saved at training time
    :param std: per-output std saved at training time
    :param trafos: same transform list used in postprocess_nLLs
    :param eps: step size for numerical differentiation
    :returns: propagated uncertainties on the unpreprocessed nLL deltas,
        shape (4,)
    :raises AssertionError: if postprocess_nLLs yields non-finite values at
        either displaced point
    """
    # Step in float64: in lower precision (e.g. float32) a 1e-5 step can be
    # rounded away for values of moderate magnitude, which would corrupt the
    # finite-difference derivative.  float64 input is unaffected.
    centre = np.asarray(nLLs_prepd, dtype=np.float64)

    f_plus = postprocess_nLLs(centre + eps, mean, std, trafos)
    f_minus = postprocess_nLLs(centre - eps, mean, std, trafos)

    deriv = np.abs(f_plus - f_minus) / (2.0 * eps)
    return deriv * errors