Source code for torch_concepts.data.utils
"""
Data utility functions for tensor manipulation and transformation.
This module provides utility functions for data processing, including tensor
conversion, image colorization, and affine transformations.
"""
import os
import numpy as np
import pandas as pd
import logging
from typing import Any, List, Sequence, Union
import torch
import random
from torch import Tensor
from torchvision.transforms import v2
logger = logging.getLogger(__name__)
[docs]
def ensure_list(value: Any) -> List:
"""
Ensure a value is converted to a list. If the value is iterable (but not a
string or dict), converts it to a list. Otherwise, wraps it in a list.
Args:
value: Any value to convert to list.
Returns:
List: The value as a list.
Examples:
>>> ensure_list([1, 2, 3])
[1, 2, 3]
>>> ensure_list((1, 2, 3))
[1, 2, 3]
>>> ensure_list(5)
[5]
>>> ensure_list("hello")
['hello']
>>> ensure_list({'a': 1, 'b': 2}) # doctest: +SKIP
TypeError: Cannot convert dict to list. Use list(dict.values())
or list(dict.keys()) explicitly.
"""
# Explicitly reject dictionaries to avoid silent conversion to keys
if isinstance(value, dict):
raise TypeError(
"Cannot convert dict to list. Use list(dict.values()) or " \
"list(dict.keys()) explicitly to make your intent clear."
)
# Check for iterables (but not strings)
if hasattr(value, '__iter__') and not isinstance(value, str):
return list(value)
else:
return [value]
[docs]
def files_exist(files: Sequence[str]) -> bool:
"""
Check if all files in a sequence exist.
Args:
files: Sequence of file paths to check.
Returns:
bool: True if all files exist, False otherwise.
Returns True for empty sequences (vacuous truth).
"""
files = ensure_list(files)
return all([os.path.exists(f) for f in files])
[docs]
def parse_tensor(data: Union[np.ndarray, pd.DataFrame, Tensor, list],
name: str,
precision: Union[int, str]) -> Tensor:
"""
Convert input data to torch tensor with appropriate format.
Supports conversion from numpy arrays, pandas DataFrames, or existing tensors.
Args:
data: Input data as numpy array, DataFrame, Tensor, list.
name: Name of the data (for error messages).
precision: Desired numerical precision (16, 32, or 64).
Returns:
Tensor: Converted tensor with specified precision.
Raises:
TypeError: If data is not in a supported format.
"""
if isinstance(data, np.ndarray):
data = torch.from_numpy(data)
return convert_precision(data, precision)
elif isinstance(data, pd.DataFrame):
data = torch.tensor(data.values)
return convert_precision(data, precision)
elif isinstance(data, Tensor):
return convert_precision(data, precision)
elif isinstance(data, list):
data = data
return data
else:
raise TypeError(f"{name} must be np.ndarray, \
pd.DataFrame, torch.Tensor or a list of filenames, got {type(data)}.")
[docs]
def convert_precision(tensor: Tensor,
precision: Union[int, str]) -> Tensor:
"""
Convert tensor to specified precision.
Args:
tensor: Input tensor.
precision: Target precision ("float16", "float32", or "float64", or 16, 32, 64).
Returns:
Tensor: Tensor converted to specified precision.
"""
if precision == "float32":
tensor = tensor.to(torch.float32)
elif precision == "float64":
tensor = tensor.to(torch.float64)
elif precision == "float16":
tensor = tensor.to(torch.float16)
return tensor
def resolve_size(size: Union[int, float], n_samples: int) -> int:
"""Convert size specification to absolute number of samples.
Args:
size: Either an integer (absolute count) or float (fraction in [0, 1]).
n_samples: Total number of samples in dataset.
Returns:
int: Absolute number of samples.
Raises:
ValueError: If fractional size is not in [0, 1] or absolute size is negative.
TypeError: If size is neither int nor float.
"""
if isinstance(size, float):
if not 0.0 <= size <= 1.0:
raise ValueError(f"Fractional size must be in [0, 1], got {size}")
return int(size * n_samples)
elif isinstance(size, int):
if size < 0:
raise ValueError(f"Absolute size must be non-negative, got {size}")
return size
else:
raise TypeError(f"Size must be int or float, got {type(size).__name__}")
[docs]
def colorize(images, colors):
"""
Colorize grayscale images based on specified colors.
Converts grayscale images to RGB by assigning the intensity to one
of three color channels (red, green, or blue).
Args:
images: Tensor of shape (N, H, W) containing grayscale images.
colors: Tensor of shape (N) containing color labels (0=red, 1=green, 2=blue).
Returns:
Tensor: Colored images of shape (N, 3, H, W).
Raises:
AssertionError: If colors contain values other than 0, 1, or 2.
"""
assert torch.unique(colors).shape[0] <= 3, "colors must be 0, 1, or 2 (red, green, blue)."
N = images.shape[0]
colored_images = torch.zeros((N, 3, images.shape[1], images.shape[2]), dtype=images.dtype, device=images.device)
indices = torch.arange(N)
colored_images[indices, colors, :, :] = images
return colored_images
[docs]
def affine_transform(images, degrees, scales, batch_size=512):
"""
Apply affine transformations to a batch of images.
Applies rotation and scaling transformations to each image.
Args:
images: Tensor of shape (N, H, W) or (N, 3, H, W).
degrees: Tensor of shape (N) containing rotation degrees.
scales: Tensor of shape (N) containing scaling factors.
batch_size: Number of images to process at once (default: 512).
Returns:
Tensor: Transformed images with same shape as input.
"""
if degrees is None:
logger.warning("Degrees for affine transformation of images not provided, setting to 0.")
degrees = torch.zeros(images.shape[0], device=images.device)
if scales is None:
logger.warning("Scales for affine transformation of images not provided, setting to 1.")
scales = torch.ones(images.shape[0], device=images.device)
N = images.shape[0]
if images.dim() == 3:
images = images.unsqueeze(1) # (N, H, W) -> (N, 1, H, W)
for i in range(0, N, batch_size):
imgs = images[i:i+batch_size]
degs = degrees[i:i+batch_size]
scs = scales[i:i+batch_size]
transformed = torch.stack([
v2.RandomAffine(degrees=(deg.item(), deg.item()), scale=(sc.item(), sc.item()))(img)
for img, deg, sc in zip(imgs, degs, scs)
])
images[i:i+batch_size] = transformed
return images
[docs]
def transform_images(images, transformations, colors=None, degrees=None, scales=None):
"""
Apply a sequence of transformations to a batch of images.
Args:
images: Tensor of shape [N, H, W] or [N, 3, H, W].
transformations: List of transformation names (e.g., ['colorize', 'affine']).
colors: Optional color labels for colorization.
degrees: Optional rotation degrees for affine transform.
scales: Optional scaling factors for affine transform.
Returns:
Tensor: Transformed images.
"""
for t in transformations:
if t == 'colorize':
if colors is None:
raise ValueError("Colors must be provided for colorize.")
images = colorize(images, colors)
elif t in ['affine']:
images = affine_transform(images, degrees=degrees, scales=scales)
else:
raise ValueError(f"Unknown transformation: {t}")
return images
[docs]
def assign_random_values(concept, random_prob=[0.5, 0.5], values = [0,1]):
"""Create a vector of random values for each sample in concepts.
Args:
concepts: Tensor of shape (N) containing concept values (e.g. digit labels 0-9).
random_prob: List of probabilities for each value.
values: List of output values corresponding to each probability.
Returns:
outputs: Tensor of shape (N) containing final values.
"""
N = len(concept)
# checks on concept
assert len(concept.shape) == 1, "concepts must be a 1D tensor."
# checks on random_prob
assert len(random_prob) > 0, "random_prob must not be empty."
assert len(random_prob) == len(values), "random_prob must have the same length as values."
assert all(0.0 <= p <= 1.0 for p in random_prob), "random_prob must be between 0 and 1."
assert abs(sum(random_prob) - 1.0) < 1e-6, "random_prob must sum to 1."
# checks on values
assert len(values) > 0, "values must not be empty."
assert len(values) == len(set(values)), "values must be unique."
probs = torch.tensor(random_prob, device=concept.device)
outputs = torch.multinomial(probs, N, replacement=True)
outputs_unique = torch.unique(outputs)
outputs_unique = sorted(outputs_unique)
mapping = {outputs_unique[i].item(): values[i] for i in range(len(outputs_unique))}
outputs= torch.tensor([mapping[i.item()] for i in outputs], device=concept.device)
return outputs
[docs]
def assign_values_based_on_intervals(concept, intervals, values):
"""Create a vector of values (0 or 1) for each sample in concepts based on intervals given.
If a concept value belongs to interval[i], it gets an output value randomly chosen among values[i].
Args:
concept: Tensor of shape (N) containing concept values (e.g. digit labels 0-9).
intervals: List of lists, each inner list contains the values defining an interval.
values: List of lists of output values corresponding to each interval.
Returns:
outputs: Tensor of shape (N) containing final values.
"""
N = len(concept)
# checks on ceoncept
assert len(concept.shape) == 1, "concepts must be a 1D tensor."
# checks on intervals
assert len(intervals) == len(values), "intervals and values must have the same length."
all_interval_values = [item for sublist in intervals for item in sublist]
assert len(all_interval_values) == len(set(all_interval_values)), "input intervals must not overlap."
assert all(len(d) > 0 for d in intervals), "each entry in intervals must contain at least one value."
# checks on values
assert all(len(v) > 0 for v in values), "each entry in values must contain at least one value."
outputs = torch.zeros_like(concept)
# create mask for each interval
for i, d in enumerate(intervals):
mask = torch.isin(concept, torch.tensor(d))
outputs[mask] = i + 1
# output must be a random value chosen among values[i] for each value i of the mask
outputs_unique = torch.unique(outputs)
outputs_unique = sorted(outputs_unique)
mapping = {outputs_unique[i].item(): values[i] for i in range(len(outputs_unique))}
outputs = torch.tensor([random.choice(mapping[i.item()]) for i in outputs], device=concept.device)
return outputs
[docs]
def colorize_and_transform(data, targets, training_percentage=0.8, test_percentage=0.2, training_mode=['random'], test_mode=['random'], training_kwargs=[{}], test_kwargs=[{}]):
"""Colorize and transform MNIST images based on specified coloring scheme.
The coloring scheme is defined differently for training and test data.
It can contain parameters for coloring, scale and rotating images.
Args:
data: Tensor of shape (N, 28, 28) containing grayscale MNIST images.
targets: Tensor of shape (N) containing target values (0-9).
training_percentage: Percentage of data to color for training.
test_percentage: Percentage of data to color for testing.
training_mode: List of coloring modes for training data. Options are 'random' and '
test_mode: List of coloring modes for test data. Options are 'random' and 'digits'.
training_kwargs: List of dictionaries containing additional arguments for each training mode.
test_kwargs: List of dictionaries containing additional arguments for each test mode.
Returns:
input: Tensor of shape (N, 3, 28, 28) containing colorized and/or transformed images.
concepts: Dictionary containing values of the parameters used for coloring and transformations (e.g., colors, scales, degrees).
targets: Tensor of shape (N) containing target values (0-9).
coloring_mode: List of strings indicating the coloring mode used for each sample ('training' or 'test').
Note: data and targets are shuffled before applying the coloring scheme.
"""
percentages = {"training": training_percentage, "test": test_percentage}
mode = {"training": training_mode, "test": test_mode}
kwargs = {"training": training_kwargs, "test": test_kwargs}
assert abs(sum(percentages.values()) - 1.0) < 1e-6, "training_percentage and test_percentage must sum to 1."
# check modality, if training_mode or test mode contain "additional_concepts"
clothing_present = False
if "additional_concepts_custom" in training_mode or "additional_concepts_custom" in test_mode:
concepts_used_training = kwargs.get("training", [{}])[0].get("concepts_used", [])
concepts_used_test = kwargs.get("test", [{}])[0].get("concepts_used", [])
if "clothing" in kwargs.get("training", [{}])[0].get("concepts_used", []) or "clothing" in kwargs.get("test", [{}])[0].get("concepts_used", []):
clothing_present = True
concepts_used_training = [c for c in concepts_used_training if c != "clothing"]
concepts_used_test = [c for c in concepts_used_test if c != "clothing"]
assert concepts_used_training == concepts_used_test, "Except for 'clothing', the concepts used must be the same in training and test."
else:
assert concepts_used_training == concepts_used_test, "Concepts used must be the same in training and test."
color_mapping = {'red': 0, 'green': 1, 'blue': 2}
N = data.shape[0]
indices = torch.randperm(N)
embeddings = torch.zeros((N, 3, data.shape[1], data.shape[2]), dtype=data.dtype)
concepts = {}
coloring_mode = ["" for _ in range(N)]
# shuffle data and targets accordingly
data = data[indices]
targets = targets[indices]
start_idx = 0
for split, perc, m, kw in zip(percentages.keys(), percentages.values(), mode.values(), kwargs.values()):
m = m[0]
kw = kw[0]
n_samples = int(perc * N)
if split == "test": # last color takes the rest
end_idx = N
else:
end_idx = start_idx + n_samples
selected_data = data[start_idx:end_idx]
selected_targets = targets[start_idx:end_idx]
if m == 'random':
# check keys of kw are exactly the ones expected
expected_keys = ['random_prob', 'values']
if set(kw.keys()) != set(expected_keys):
raise ValueError(f"random coloring requires the following keys in kwargs: {expected_keys}")
# load values from kw
prob_mod = kw.get('random_prob')
colors = kw.get('values')
# checks on 'random_prob'
assert isinstance(prob_mod, list), "random_prob must be a list."
# checks on 'values'
assert isinstance(colors, list), "values must be a list."
if not all(v in color_mapping for v in colors):
raise ValueError(f"All values must be one of {list(color_mapping.keys())}.")
assert len(colors) == len(set(colors)), "colors must not repeat."
# transform prob_mod if needed
if prob_mod[0] == 'uniform':
random_prob = [1.0 / (len(colors))] * (len(colors))
else:
random_prob = prob_mod
# calculate concept values and transform images accordingly
numeric_colors = [color_mapping[v] for v in colors]
random_colors = assign_random_values(selected_targets, random_prob=random_prob, values=numeric_colors)
colored_data = transform_images(selected_data, transformations=["colorize"], colors=random_colors)
selected_concepts = {'colors': random_colors}
elif m == 'intervals':
# check keys of kw are exactly the ones expected
expected_keys = ['intervals', 'values']
if set(kw.keys()) != set(expected_keys):
raise ValueError(f"intervals coloring requires the following keys in kwargs: {expected_keys}")
# load values from kw
interval_values = kw.get('intervals')
colors = kw.get('values')
# checks on 'intervals'
assert all(isinstance(v, list) for v in interval_values), "each entry in intervals must be a list."
assert len(interval_values) == len(colors), "intervals and values must have the same length."
all_interval_values = [item for sublist in interval_values for item in sublist]
unique_targets = torch.unique(selected_targets).tolist()
assert set(all_interval_values) == set(unique_targets), f"intervals must cover all target values, i.e.: {unique_targets}"
assert set(all_interval_values).issubset(set(range(10))), "interval values must be between 0 and 9."
# checks on 'values'
assert all(isinstance(v, list) for v in colors), "each entry in colors must be a list."
all_colors_values = [item for sublist in colors for item in sublist]
if not all(v in color_mapping for v in all_colors_values):
raise ValueError(f"All values must be one of {list(color_mapping.keys())}.")
# calculate concept values and transform images accordingly
numeric_colors = [[color_mapping[v] for v in sublist] for sublist in colors]
interval_colors = assign_values_based_on_intervals(selected_targets, intervals=interval_values, values=numeric_colors)
colored_data = transform_images(selected_data, transformations=["colorize"], colors=interval_colors)
selected_concepts = {'colors': interval_colors}
elif m == 'additional_concepts_custom':
# check keys of kw are exactly the ones expected
expected_keys = ['concepts_used', 'values']
if set(kw.keys()) != set(expected_keys):
raise ValueError(f"additional_concepts_custom coloring requires the following keys in kwargs: {expected_keys}")
# load values from kw
concepts_used = kw.get('concepts_used')
values = kw.get('values')
# checks on 'concepts_used'
assert isinstance(concepts_used, list), "concepts_used must be a list."
#assert len(concepts_used) == 3, "There must be 3 concepts used."
assert len(concepts_used) == len(values), "concepts_used and values must have the same length."
assert 'colors' in concepts_used, "concepts_used must contain 'color'"
# checks on 'values'
assert all(isinstance(v, list) for v in values), "each entry in values must be a list."
lengths = [len(v) for v in values]
assert all(l == lengths[0] for l in lengths), "each entry in values must have the same length."
# if "clothing" is in concept_used, check all values are present
if 'clothing' in concepts_used:
# it must be in the first position
assert concepts_used.index('clothing') == 0, "If 'clothing' is used, it must be the first concept."
clothing_values = values[concepts_used.index('clothing')]
all_clothing = set(range(10))
provided_clothing = set([item for sublist in clothing_values for item in sublist])
assert all_clothing.issubset(provided_clothing), "All clothing values (0-9) must be present in clothing values."
assert provided_clothing.issubset(all_clothing), "Clothing values must be between 0 and 9."
# calculate concept values and transform images accordingly
idx_color = concepts_used.index('colors')
values[idx_color] = [[color_mapping[c] for c in sublist] for sublist in values[idx_color]]
if concepts_used[0] !="clothing":
# if concept 0 is not clothing, assign random values to samples from values[0]
concept_0_values = [item for sublist in values[0] for item in sublist]
random_prob = [1.0 / len(concept_0_values)] * (len(concept_0_values))
concept_0 = assign_random_values(selected_targets, random_prob = random_prob, values = concept_0_values)
else:
concept_0 = selected_targets
selected_concepts = {}
selected_concepts[concepts_used[0]] = concept_0
for i in range(1,len(concepts_used)):
selected_concepts[concepts_used[i]] = assign_values_based_on_intervals(selected_concepts[concepts_used[i-1]],
intervals = values[i-1],
values = values[i])
if 'clothing' in selected_concepts:
del selected_concepts['clothing']
idx_scale = concepts_used.index('scales') if 'scales' in concepts_used else None
idx_degree = concepts_used.index('degrees') if 'degrees' in concepts_used else None
colored_data = transform_images(selected_data,
transformations=["colorize", "affine"],
colors= selected_concepts[concepts_used[idx_color]],
degrees= selected_concepts[concepts_used[idx_degree]] if idx_degree is not None else None,
scales= selected_concepts[concepts_used[idx_scale]] if idx_scale is not None else None)
elif m == 'additional_concepts_random':
# check keys of kw are exactly the ones expected
expected_keys = ['concepts_used', 'values', 'random_prob']
if set(kw.keys()) != set(expected_keys):
raise ValueError(f"additional_concepts_random coloring requires the following keys in kwargs: {expected_keys}")
# load values from kw
concepts_used = kw.get('concepts_used', [])
values = kw.get('values', [])
prob_mod = kw.get('random_prob')
# checks on 'concepts_used'
assert isinstance(concepts_used, list), "concepts_used must be a list."
assert len(concepts_used) == len(values), "concepts_used and values must have the same length."
assert len(concepts_used) == len(prob_mod), "concepts_used and random_prob must have the same length."
assert 'colors' in concepts_used, "concepts_used must contain 'colors'"
assert 'clothing' not in concepts_used, "'clothing' cannot be used in additional_concepts_random coloring."
# checks on 'values'
assert all(isinstance(v, list) for v in values), "each entry in values must be a list."
# checks on 'random_prob'
assert all(isinstance(v, list) for v in prob_mod), "each entry in random_prob must be a list."
# transform prob_mod if needed
random_prob = {}
for i in range(len(prob_mod)):
random_prob[i] = []
if prob_mod[i][0] == 'uniform':
random_prob[i] = [1.0 / (len(values[i]))] * (len(values[i]))
else:
random_prob[i] = prob_mod[i]
# calculate concept values and transform images accordingly
idx_color = concepts_used.index('colors')
values[idx_color] = [color_mapping[c] for c in values[idx_color]]
selected_concepts = {}
for i in range(len(concepts_used)):
selected_concepts[concepts_used[i]] = assign_random_values(selected_targets,
random_prob = random_prob[i],
values = values[i])
idx_scale = concepts_used.index('scales') if 'scales' in concepts_used else None
idx_degree = concepts_used.index('degrees') if 'degrees' in concepts_used else None
colored_data = transform_images(selected_data,
transformations=["colorize", "affine"],
colors= selected_concepts[concepts_used[idx_color]],
degrees= selected_concepts[concepts_used[idx_degree]] if idx_degree is not None else None,
scales= selected_concepts[concepts_used[idx_scale]] if idx_scale is not None else None)
else:
raise ValueError(f"Unknown coloring mode: {m}")
# assign to the main tensors and dict
embeddings[start_idx:end_idx] = colored_data
for k, v in selected_concepts.items():
if k not in concepts:
concepts[k] = torch.zeros(N, dtype=v.dtype)
concepts[k][start_idx:end_idx] = v
coloring_mode[start_idx:end_idx] = [split] * selected_data.shape[0]
start_idx = end_idx
return embeddings, concepts, targets, coloring_mode