# Source code for torch_concepts.data.base.scaler
"""Abstract base class for data scaling transformations.
This module defines the Scaler interface that all data scalers must implement.
Scalers are used to normalize and denormalize data during training and inference.
"""
from abc import ABC, abstractmethod
from torch import Tensor
class Scaler(ABC):
    """Abstract base class for data scaling transformations.

    Provides a consistent interface for fitting scalers to data and applying
    forward/inverse transformations. All concrete scaler implementations should
    inherit from this class and implement fit(), transform(), and
    inverse_transform() methods.

    Args:
        bias (float, optional): Initial bias value, stored as ``self.bias``.
            Defaults to 0.0.
        scale (float, optional): Initial scale value, stored as
            ``self.scale``. Defaults to 1.0.

    Example:
        A minimal min-max scaler. Note that this sketch divides by zero for
        features whose max equals their min.

        >>> class MinMaxScaler(Scaler):
        ...     def fit(self, x, dim=0):
        ...         self.min = x.min(dim=dim, keepdim=True)[0]
        ...         self.max = x.max(dim=dim, keepdim=True)[0]
        ...         return self
        ...
        ...     def transform(self, x):
        ...         return (x - self.min) / (self.max - self.min)
        ...
        ...     def inverse_transform(self, x):
        ...         return x * (self.max - self.min) + self.min
    """
def __init__(self, bias=0., scale=1.):
self.bias = bias
self.scale = scale
super(Scaler, self).__init__()
@abstractmethod
def fit(self, x: Tensor, dim: int = 0) -> "Scaler":
"""Fit the scaler to the input data.
Args:
x: Input tensor to fit the scaler to.
dim: Dimension along which to compute statistics (default: 0).
Returns:
self: The fitted scaler instance for method chaining.
"""
pass