# Source code for pronoms.normalizers.median_polish_normalizer
import matplotlib.pyplot as plt
import numpy as np
from ..utils.validators import check_nan_inf, validate_input_data
# [docs]
class MedianPolishNormalizer:
    """
    Normalizer based on Tukey's Median Polish algorithm.

    The algorithm alternately removes row (sample) medians and column
    (feature) medians from a matrix, typically one holding log-transformed
    data. It decomposes the working matrix ``X`` into::

        X[i, j] = overall_median + row_effect[i] + col_effect[j] + residual[i, j]

    ``normalize`` returns ``overall_median + residuals``. When
    ``log_transform`` is True the polish runs on ``log(X + epsilon)`` and the
    returned matrix stays in log-space; it is not exponentiated back.

    Attributes
    ----------
    max_iterations : int
        Maximum number of iterations allowed for the algorithm.
    tolerance : float
        Convergence threshold on the per-iteration row/column medians removed
        from the residual matrix. The algorithm stops as soon as
        ``max(|row_medians|.max(), |col_medians|.max()) <= tolerance``.
    epsilon : float
        Small constant added before log transformation to handle zero values.
    log_transform : bool
        Whether to apply a natural-log transformation before median polish.
    row_effects : Optional[np.ndarray]
        The calculated median effects for each row (sample). Available after normalize().
    col_effects : Optional[np.ndarray]
        The calculated median effects for each column (feature). Available after normalize().
    overall_median : Optional[float]
        The calculated overall median effect. Available after normalize().
    residuals : Optional[np.ndarray]
        The final residuals after removing row, column, and overall effects.
        Available after normalize().
    converged : Optional[bool]
        Whether the algorithm converged within max_iterations. Available after normalize().
    iterations_run : Optional[int]
        Number of iterations actually performed. Available after normalize().
    """

    def __init__(
        self,
        max_iterations: int = 10,
        tolerance: float = 0.01,
        epsilon: float = 1e-6,
        log_transform: bool = True,
    ):
        """
        Initialize the MedianPolishNormalizer.

        Parameters
        ----------
        max_iterations : int, optional
            Maximum number of iterations, by default 10. Python ints and
            NumPy integer scalars are accepted; booleans are not.
        tolerance : float, optional
            Convergence tolerance on the row/column medians removed in each
            iteration. The polish stops when the largest absolute row- or
            column-median in the current iteration is ``<= tolerance``.
            By default 0.01.
        epsilon : float, optional
            Small constant added before log transformation (if used), by default 1e-6.
            Only used if `log_transform` is True.
        log_transform : bool, optional
            Whether to log-transform the data before applying median polish,
            by default True. The result is left in log-space.

        Raises
        ------
        ValueError
            If any argument has the wrong type or is out of range
            (including NaN for ``tolerance`` / ``epsilon``).
        """
        if not self._is_integer(max_iterations) or max_iterations <= 0:
            raise ValueError("max_iterations must be a positive integer")
        if not self._is_non_negative_number(tolerance):
            raise ValueError("tolerance must be a non-negative number")
        if not self._is_non_negative_number(epsilon):
            raise ValueError("epsilon must be a non-negative number")
        if not isinstance(log_transform, (bool, np.bool_)):
            raise ValueError("log_transform must be a boolean")

        # Normalise NumPy scalars to the plain Python types documented above.
        self.max_iterations = int(max_iterations)
        self.tolerance = float(tolerance)
        self.epsilon = float(epsilon)
        self.log_transform = bool(log_transform)

        # Results attributes initialized to None
        self.row_effects: np.ndarray | None = None
        self.col_effects: np.ndarray | None = None
        self.overall_median: float | None = None
        self.residuals: np.ndarray | None = None
        self.converged: bool | None = None
        self.iterations_run: int | None = None

    @staticmethod
    def _is_integer(value: object) -> bool:
        """Return True for Python/NumPy integers, excluding booleans."""
        # bool is a subclass of int, so it has to be ruled out explicitly.
        return isinstance(value, (int, np.integer)) and not isinstance(value, bool)

    @staticmethod
    def _is_non_negative_number(value: object) -> bool:
        """Return True for real Python/NumPy numbers that are >= 0 (NaN and bool fail)."""
        if isinstance(value, bool):
            return False
        if not isinstance(value, (int, float, np.integer, np.floating)):
            return False
        # ``NaN >= 0`` is False, so NaN is rejected here as well.
        return bool(value >= 0)

    def normalize(self, X: np.ndarray) -> np.ndarray:
        """
        Apply Tukey's Median Polish normalization to the data.

        If log_transform is True, the polish runs on ``log(X + epsilon)``.
        The method returns ``overall_median + residuals`` and stores the
        fitted effects on the instance. The result attributes are only
        updated once the polish has completed.

        Note: If log_transform was used, the returned data remains in log-space.

        Parameters
        ----------
        X : np.ndarray
            Input data matrix (n_samples, n_features).

        Returns
        -------
        np.ndarray
            Normalized data matrix (overall_median + residuals).
            If log_transform=True, this matrix is in log-space.

        Raises
        ------
        ValueError
            If the working matrix is not a non-empty 2D matrix, or if it
            contains NaN/Inf values (for example because ``X + epsilon`` is
            not strictly positive when log_transform is True).
        """
        # validate_input_data is defined elsewhere; its result is treated as array-like.
        data = np.asarray(validate_input_data(X))

        if self.log_transform:
            # Invalid values are reported explicitly below, so silence the
            # floating-point warnings np.log would otherwise emit for them.
            with np.errstate(divide="ignore", invalid="ignore"):
                resid = np.log(data + self.epsilon)
        elif np.issubdtype(data.dtype, np.floating):
            resid = data.copy()
        else:
            # The polish subtracts float medians in place, which needs a
            # floating-point working matrix.
            resid = data.astype(np.float64)

        if resid.ndim != 2 or resid.size == 0:
            raise ValueError(
                "Input data must be a non-empty 2D matrix (n_samples, n_features)."
            )

        # second sanity check
        has_nan_inf, _ = check_nan_inf(resid)
        if has_nan_inf:
            if self.log_transform:
                raise ValueError(
                    "Input contains NaN or Inf values after log transformation; "
                    "every entry must satisfy X + epsilon > 0."
                )
            raise ValueError("Input contains NaN or Inf values.")

        n_rows, n_cols = resid.shape

        # Work on locals and publish to the instance only after the loop.
        row_effects = np.zeros(n_rows)
        col_effects = np.zeros(n_cols)
        overall_median = 0.0
        converged = False
        iterations_run = 0

        for iterations_run in range(1, self.max_iterations + 1):
            # ----- row step: sweep row medians out of the residuals ------
            row_med = np.median(resid, axis=1)
            resid -= row_med[:, None]
            row_effects += row_med
            # centre row effects and move the shift into the overall term
            row_shift = float(np.median(row_effects))
            row_effects -= row_shift
            overall_median += row_shift

            # ----- column step: sweep column medians out -----------------
            col_med = np.median(resid, axis=0)
            resid -= col_med
            col_effects += col_med
            # centre column effects and move the shift into the overall term
            col_shift = float(np.median(col_effects))
            col_effects -= col_shift
            overall_median += col_shift

            # ----- convergence check on the medians just removed ---------
            max_change = max(np.abs(row_med).max(), np.abs(col_med).max())
            if max_change <= self.tolerance:
                converged = True
                break

        self.row_effects = row_effects
        self.col_effects = col_effects
        self.overall_median = overall_median
        self.residuals = resid
        self.converged = converged
        self.iterations_run = iterations_run

        # A new array is returned; ``self.residuals`` keeps the bare residuals.
        return overall_median + resid

    def plot_comparison(
        self,
        original_data: np.ndarray,
        normalized_data: np.ndarray,
        figsize: tuple[int, int] = (10, 8),
    ) -> "plt.Figure":
        """
        Generate a hexbin plot comparing original data (log scale) vs. normalized data.

        The original data (x-axis) is always drawn on a log-scaled axis, so
        only points with a positive, finite original value and a finite
        normalized value are plotted. A dashed reference curve is overlaid:
        ``y = log(x)`` when ``self.log_transform`` is True, otherwise ``y = x``.

        Parameters
        ----------
        original_data : np.ndarray
            The raw data matrix (n_samples, n_features).
        normalized_data : np.ndarray
            The data matrix after normalization (n_samples, n_features).
            This will be in log-space if log_transform=True was used.
        figsize : Tuple[int, int], optional
            Figure size for the plot, by default (10, 8).

        Returns
        -------
        plt.Figure
            Matplotlib figure object containing the hexbin plot.

        Raises
        ------
        ValueError
            If the two inputs do not contain the same number of elements.
        """
        # Flatten data for hexbin plot
        x_flat = np.asarray(original_data, dtype=float).ravel()
        y_flat = np.asarray(normalized_data, dtype=float).ravel()
        if x_flat.size != y_flat.size:
            raise ValueError(
                "original_data and normalized_data must contain the same number of elements"
            )

        fig, ax = plt.subplots(figsize=figsize)
        ax.set_title("Median Polish Normalization Comparison")
        ax.set_xlabel("Original Data (Log Scale)")

        # A log x-axis cannot show non-positive values, and non-finite
        # points cannot be binned.
        keep = (x_flat > 0) & np.isfinite(x_flat) & np.isfinite(y_flat)
        x_kept = x_flat[keep]
        y_kept = y_flat[keep]

        if x_kept.size == 0:
            ax.text(0.5, 0.5, "No positive data to plot on log scale", ha="center", va="center")
            ax.set_ylabel("Normalized Data")
            return fig

        # Create hexbin plot
        hb = ax.hexbin(x_kept, y_kept, gridsize=50, cmap="viridis", xscale="log")
        fig.colorbar(hb, ax=ax, label="Count in bin")

        # Sample the reference curve at many log-spaced x values. A segment
        # between only the two end points is drawn straight on screen, which
        # is not the curve y = x when the x-axis is logarithmic.
        ref_x = np.geomspace(float(x_kept.min()), float(x_kept.max()), num=200)
        if self.log_transform:
            ref_y = np.log(ref_x)
            ref_label = "y = log(x)"
            y_label = "Normalized Data (Log Scale)"
        else:
            ref_y = ref_x
            ref_label = "y = x"
            y_label = "Normalized Data"

        ax.plot(ref_x, ref_y, color="red", linestyle="--", linewidth=1, label=ref_label)
        ax.set_ylabel(y_label)
        ax.legend()
        fig.tight_layout()
        return fig