Source code for pronoms.normalizers.median_normalizer
"""
Median Normalizer for proteomics data.
This module provides a class for median normalization of proteomics data.
"""
import matplotlib.pyplot as plt
import numpy as np
from ..utils.plotting import create_hexbin_comparison
from ..utils.validators import check_nan_inf, validate_input_data
[docs]
class MedianNormalizer:
    """
    Normalizer that scales each sample by its median.

    This normalizer adjusts each sample (row) in the data matrix by dividing
    that sample by its own median and then multiplying by the mean of all
    sample medians. After normalization every row's median equals
    ``mean_of_medians``, preserving the overall scale of the dataset rather
    than collapsing every row to a median of 1.

    Inputs are arranged as ``(n_samples, n_features)`` (rows are samples,
    columns are proteins/features), following the sklearn convention.

    Attributes
    ----------
    scaling_factors : Optional[np.ndarray]
        Per-sample medians used as the divisor (one value per row).
        Only available after calling normalize().
    mean_of_medians : Optional[float]
        Mean of ``scaling_factors``; the common value every row's median is
        rescaled to. Only available after calling normalize().
    """

    def __init__(self):
        """Initialize the MedianNormalizer with no fitted scaling state."""
        self.scaling_factors = None
        self.mean_of_medians = None

    def normalize(self, X: np.ndarray) -> np.ndarray:
        """
        Perform median normalization on input data X.

        Each row is divided by its own median and multiplied by the mean of
        all row medians, so every row ends up with the same median.

        Parameters
        ----------
        X : np.ndarray
            Input data matrix with shape (n_samples, n_features).
            Each row represents a sample, each column represents a feature/protein.
            The shape check uses ``np.ndim``/``np.shape``, so array-likes that
            ``validate_input_data`` accepts are checked without first requiring
            an ``ndarray``.

        Returns
        -------
        np.ndarray
            Normalized data matrix with the same shape as X.

        Raises
        ------
        ValueError
            - If input is not a 2D array with at least one feature.
            - If input has no samples (zero rows).
            - If input data contains NaN or Inf values.
            - If any sample's median is <= 0 (protein quantities must be positive).
        """
        # Dimensionality guard. np.ndim / np.shape fall back to converting
        # array-likes, so e.g. a nested list that is not 2D gets the
        # documented ValueError instead of an AttributeError on `.ndim`.
        if np.ndim(X) != 2 or np.shape(X)[1] == 0:
            raise ValueError("X must be a 2D array with at least one feature (n_samples, n_features).")
        # With zero rows the per-sample medians are empty and their mean would
        # be NaN (with a RuntimeWarning), leaving meaningless scaling state.
        if np.shape(X)[0] == 0:
            raise ValueError("X must contain at least one sample (n_samples, n_features).")

        # Validate input data (dtype conversion, etc.)
        X = validate_input_data(X)

        # Check for NaN or Inf values
        has_nan_inf, _ = check_nan_inf(X)
        if has_nan_inf:
            raise ValueError("Input data contains NaN or Inf values. Please handle these values before normalization.")

        # Compute per-sample (row-wise) medians, shape (n_samples,)
        medians = np.median(X, axis=1)

        # Enforce strictly positive medians: a zero or negative divisor would
        # produce Inf or sign-flipped rows.
        if np.any(medians <= 0):
            raise ValueError("All sample medians must be > 0.")

        # Store scaling state (only after all checks pass)
        mean_of_medians = float(np.mean(medians))
        self.scaling_factors = medians
        self.mean_of_medians = mean_of_medians

        # medians[:, np.newaxis] has shape (n_samples, 1) so it broadcasts
        # across every feature of its row.
        normalized_data = (X / medians[:, np.newaxis]) * mean_of_medians
        return normalized_data

    def plot_comparison(
        self,
        before_data: np.ndarray,
        after_data: np.ndarray,
        figsize: tuple[int, int] = (10, 8),
        title: str = "Median Normalization Comparison",
        log_axes: bool = True,
    ) -> plt.Figure:
        """
        Plot data before vs after normalization using a 2D hexbin density plot.

        Parameters
        ----------
        before_data : np.ndarray
            Data before normalization, shape (n_samples, n_features).
        after_data : np.ndarray
            Data after normalization, shape (n_samples, n_features).
        figsize : tuple[int, int], optional
            Figure size, by default (10, 8).
        title : str, optional
            Plot title, by default "Median Normalization Comparison".
        log_axes : bool, optional
            If True (default), plot log10 of the values on both axes. If False, plot raw values.

        Returns
        -------
        plt.Figure
            Figure object containing the hexbin density plot.
        """
        # Validate input data
        before_data = validate_input_data(before_data)
        after_data = validate_input_data(after_data)

        # Create hexbin comparison plot
        fig = create_hexbin_comparison(
            before_data,
            after_data,
            figsize=figsize,
            title=title,
            xlabel="Before Median Normalization",
            ylabel="After Median Normalization",
            log_axes=log_axes,
        )
        return fig