# Source code for pronoms.normalizers.median_normalizer

"""
Median Normalizer for proteomics data.

This module provides a class for median normalization of proteomics data.
"""

import matplotlib.pyplot as plt
import numpy as np

from ..utils.plotting import create_hexbin_comparison
from ..utils.validators import check_nan_inf, validate_input_data


class MedianNormalizer:
    """
    Normalizer that scales each sample by its median.

    This normalizer adjusts each sample (row) in the data matrix by dividing
    that sample by its own median and then multiplying by the mean of all
    sample medians. After normalization every row's median equals
    ``mean_of_medians``, preserving the overall scale of the dataset rather
    than collapsing every row to a median of 1.

    Inputs are arranged as ``(n_samples, n_features)`` (rows are samples,
    columns are proteins/features), following the sklearn convention.

    Attributes
    ----------
    scaling_factors : Optional[np.ndarray]
        Per-sample medians used as the divisor (one value per row).
        Only available after calling normalize().
    mean_of_medians : Optional[float]
        Mean of ``scaling_factors``; the common value every row's median is
        rescaled to. Only available after calling normalize().
    """

    def __init__(self):
        """Initialize the MedianNormalizer with no fitted state."""
        self.scaling_factors = None
        self.mean_of_medians = None

    def normalize(self, X: np.ndarray) -> np.ndarray:
        """
        Perform median normalization on input data X.

        Parameters
        ----------
        X : np.ndarray
            Input data matrix with shape (n_samples, n_features). Each row
            represents a sample, each column represents a feature/protein.

        Returns
        -------
        np.ndarray
            Normalized data matrix with the same shape as X, computed as
            ``(X / row_median) * mean_of_medians``.

        Raises
        ------
        ValueError
            - If input is not a 2D array with at least one feature.
            - If input has no samples (zero rows).
            - If input data contains NaN or Inf values.
            - If any sample's median is <= 0 (protein quantities must be
              positive).
        """
        # Dimensionality guard. np.shape works on any array-like (nested
        # lists, DataFrames, ndarrays), so inputs lacking a ``.ndim`` /
        # ``.shape`` attribute still get the intended ValueError rather than
        # an AttributeError before validate_input_data has converted them.
        shape = np.shape(X)
        if len(shape) != 2 or shape[1] == 0:
            raise ValueError(
                "X must be a 2D array with at least one feature (n_samples, n_features)."
            )
        # With zero rows, np.mean of the (empty) medians would warn and yield
        # NaN, silently storing a NaN ``mean_of_medians``; reject it up front.
        if shape[0] == 0:
            raise ValueError("X must contain at least one sample (n_samples > 0).")

        # Validate input data (dtype conversion, etc.)
        X = validate_input_data(X)

        # Check for NaN or Inf values
        has_nan_inf, _ = check_nan_inf(X)
        if has_nan_inf:
            raise ValueError(
                "Input data contains NaN or Inf values. Please handle these values before normalization."
            )

        # Per-sample (row-wise) medians, shape (n_samples,).
        medians = np.median(X, axis=1)

        # A non-positive median would flip signs or divide by zero.
        if np.any(medians <= 0):
            raise ValueError("All sample medians must be > 0.")

        mean_of_medians = float(np.mean(medians))

        # Store scaling state (1-D medians, one per sample).
        self.scaling_factors = medians
        self.mean_of_medians = mean_of_medians

        # Broadcast each row's median across its columns, then rescale so
        # every row's median becomes mean_of_medians.
        normalized_data = (X / medians[:, np.newaxis]) * mean_of_medians
        return normalized_data

    def plot_comparison(
        self,
        before_data: np.ndarray,
        after_data: np.ndarray,
        figsize: tuple[int, int] = (10, 8),
        title: str = "Median Normalization Comparison",
        log_axes: bool = True,
    ) -> plt.Figure:
        """
        Plot data before vs after normalization using a 2D hexbin density plot.

        Parameters
        ----------
        before_data : np.ndarray
            Data before normalization, shape (n_samples, n_features).
        after_data : np.ndarray
            Data after normalization, shape (n_samples, n_features).
        figsize : tuple[int, int], optional
            Figure size, by default (10, 8).
        title : str, optional
            Plot title, by default "Median Normalization Comparison".
        log_axes : bool, optional
            If True (default), plot log10 of the values on both axes.
            If False, plot raw values.

        Returns
        -------
        plt.Figure
            Figure object containing the hexbin density plot.
        """
        # Validate input data
        before_data = validate_input_data(before_data)
        after_data = validate_input_data(after_data)

        # Delegate rendering to the shared hexbin helper.
        fig = create_hexbin_comparison(
            before_data,
            after_data,
            figsize=figsize,
            title=title,
            xlabel="Before Median Normalization",
            ylabel="After Median Normalization",
            log_axes=log_axes,
        )
        return fig