Source code for torch_concepts.data.scalers.standard
"""Standard scaling (z-score normalization) for data preprocessing.
This module provides StandardScaler for normalizing data to zero mean and
unit variance, similar to scikit-learn's StandardScaler but for PyTorch tensors.
"""
from typing import Tuple, Union
import torch
from torch import Tensor
from ..base.scaler import Scaler
def zeros_to_one_(scale: Union[float, Tensor]) -> Union[float, Tensor]:
    """Set to 1 the scales of near-constant features to avoid division by zero.

    Detects scale values that are numerically indistinguishable from zero and
    replaces them with 1.0, so that dividing by the scale leaves such features
    unchanged instead of producing ``inf``/``nan``. Tensor inputs are modified
    in-place (hence the trailing underscore) and the same tensor object is
    returned.

    Adapted from sklearn.preprocessing._data._handle_zeros_in_scale and
    tsl.data.preprocessing.scalers.zeros_to_one_

    Args:
        scale (Union[float, Tensor]): Python scalar (``int`` or ``float``) or
            floating-point tensor of scale values to check.

    Returns:
        Union[float, Tensor]: For scalars, ``1.0`` if the value is within
        ``1e-8`` of zero, otherwise the value itself. For tensors, the input
        tensor with entries within ``10 * eps`` of zero (``eps`` being the
        machine epsilon of the tensor's dtype) overwritten by 1.0.

    Example:
        >>> scales = torch.tensor([1.0, 0.0000001, 2.5, 0.0])
        >>> zeros_to_one_(scales)
        tensor([1.0000, 1.0000, 2.5000, 1.0000])
        >>> zeros_to_one_(0)
        1.0
    """
    if isinstance(scale, (int, float)):
        # Compare in plain Python rather than through torch.isclose: wrapping
        # an int in torch.tensor() yields an integer tensor, and isclose
        # refuses to compare it against the float zero (dtype mismatch).
        # 1e-8 is torch.isclose's default absolute tolerance, i.e. the
        # threshold that was effectively applied to float inputs before.
        return 1.0 if abs(scale) <= 1e-8 else scale
    # Tolerance scales with the precision of the tensor's dtype.
    eps = 10 * torch.finfo(scale.dtype).eps
    zero = torch.tensor(0.0, device=scale.device, dtype=scale.dtype)
    zeros = torch.isclose(scale, zero, atol=eps, rtol=eps)
    # In-place overwrite: callers rely on the input tensor being updated.
    scale[zeros] = 1.0
    return scale
[docs]
class StandardScaler(Scaler):
"""Z-score normalization scaler for PyTorch tensors.
Standardizes features by removing the mean and scaling to unit variance:
z = (x - μ) / σ
This scaler is useful for:
- Normalizing input features before training
- Ensuring all features are on the same scale
- Improving gradient flow and training stability
Args:
axis (Union[int, Tuple], optional): Axis or axes along which to compute
mean and standard deviation. Typically 0 (across samples) for
feature-wise normalization. Defaults to 0.
Attributes:
mean (Tensor): Computed mean value(s) from fitted data.
std (Tensor): Computed standard deviation(s) from fitted data.
Example:
>>> # Normalize a batch of features
>>> scaler = StandardScaler(axis=0)
>>> X_train = torch.randn(1000, 50) # 1000 samples, 50 features
>>> X_train_scaled = scaler.fit_transform(X_train)
>>>
>>> # Transform test data using training statistics
>>> X_test = torch.randn(200, 50)
>>> X_test_scaled = scaler.transform(X_test)
>>>
>>> # Inverse transform to original scale
>>> X_recovered = scaler.inverse_transform(X_test_scaled)
"""
[docs]
def __init__(self, axis: Union[int, Tuple] = 0):
    """Create a scaler that computes its statistics over ``axis``.

    Args:
        axis: Axis or axes along which the mean and standard deviation
            are reduced when fitting (default: 0).
    """
    super().__init__()
    # Stored as given; it is passed straight to the tensor reductions in fit.
    self.axis = axis
[docs]
def fit(self, x: Tensor) -> "StandardScaler":
    """Compute mean and standard deviation along the configured axis.

    Both statistics are reduced over ``self.axis`` with ``keepdim=True`` and
    stored on the instance as ``mean`` and ``std``. Near-zero deviations are
    replaced by 1.0 so that constant features do not cause a division by
    zero. When fewer than two values are reduced into each statistic the
    sample standard deviation is undefined (NaN), so a unit scale is used
    instead.

    Args:
        x: Input tensor to compute statistics from.

    Returns:
        self: The fitted scaler instance for method chaining.
    """
    self.mean = x.mean(dim=self.axis, keepdim=True)
    # With keepdim=True the mean holds one entry per retained position, so
    # the ratio of element counts is the number of values reduced into each
    # statistic.
    reduced = x.numel() // max(self.mean.numel(), 1)
    if reduced < 2:
        # Tensor.std applies Bessel's correction by default and returns NaN
        # here; NaN is not caught by zeros_to_one_ and would poison every
        # later division by std. Treat the features as constant instead.
        self.std = torch.ones_like(self.mean)
    else:
        self.std = zeros_to_one_(x.std(dim=self.axis, keepdim=True))
    return self