Source code for torch_concepts.data.datasets.mnist

import numpy as np
import random
import torch
from typing import Tuple

from torchvision.datasets import MNIST
from torchvision.transforms import transforms


def _colorize(image: torch.Tensor, color: str) -> torch.Tensor:
    # Create an image with 3 channels (RGB)
    colored_image = torch.zeros(3, 28, 28)
    if color == 'red':
        colored_image[0] = image  # Red channel
    elif color == 'green':
        colored_image[1] = image  # Green channel
    return colored_image


[docs] class ColorMNISTDataset(MNIST): """ The color MNIST dataset is a modified version of the MNIST dataset where each digit is colored either red or green. The concept labels are the digit and the color of the digit. The task is to predict whether the digit is even or odd. Attributes: root: The root directory where the dataset is stored. train: Whether to load the training or test split. Default is False. transform: The transformations to apply to the images. Default is None. target_transform: The transformations to apply to the target labels. Default is None. download: Whether to download the dataset if it does not exist. Default is False. random: Whether to colorize the digits randomly. Default is True. """
[docs] def __init__( self, root: str, train: bool = False, transform = None, target_transform = None, download: bool = False, random: bool = True, ): super(ColorMNISTDataset, self).__init__( root, train=train, transform=transform, target_transform=target_transform, download=download, ) self.random = random self.concept_attr_names = [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'red', 'green', ] self.task_attr_names = ['even', 'odd']
def __len__(self): return len(self.data) def __getitem__(self, index): image, digit = self.data[index], int(self.targets[index]) # Colorize the image if self.random: color = 'red' if random.random() < 0.5 else 'green' else: color = 'red' if digit <= 5 else 'green' # Remove channel dimension of the grayscale image colored_image = _colorize(image.squeeze(), color) # Create the concept label concept_label = np.zeros(12) # 10 digits + 2 colors concept_label[digit] = 1 concept_label[10] = 1 if color == 'red' else 0 concept_label[11] = 1 if color == 'green' else 0 # Create the target label target_label = 1 if digit % 2 == 0 else 0 target_label = [target_label, 1 - target_label] return ( colored_image, torch.tensor(concept_label, dtype=torch.float32), torch.tensor(target_label, dtype=torch.float32), )
[docs] class MNISTAddition(MNIST): """ The MNIST addition dataset is a modified version of the MNIST dataset where each image is a concatenation of two MNIST images and the target label is the sum of the two digits. The concept label is a one-hot encoding of the two digits. Attributes: concept_names: The names of the concept labels. task_names: The names of the task labels. root: The root directory where the dataset is stored. train: Whether to load the training or test split. Default is False. transform: The transformations to apply to the images. Default is None. target_transform: The transformations to apply to the target labels. Default is None. download: Whether to download the dataset if it does not exist. Default is False. """ name = "mnist_addition" n_concepts = 20 n_tasks = 19 concept_names = [ "0_left", "1_left", "2_left", "3_left", "4_left", "5_left", "6_left", "7_left", "8_left", "9_left", "0_right", "1_right", "2_right", "3_right", "4_right", "5_right", "6_right", "7_right", "8_right", "9_right", ] task_names = [ "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", ] transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) input_shape = (1, 28, 56) input_dim = 28 * 56
[docs] def __init__(self, root, train, target_transform=None, download=True): super(MNISTAddition, self).__init__( root, train, self.transform, target_transform, download, )
def __getitem__( self, index, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # Get the first image and target img_1, target_1 = super(MNISTAddition, self).__getitem__(index) # Get a second image and target. To get a different image, we need to # sample a different index. To do this, we pick the index from the end img_2, target_2 = super(MNISTAddition, self).__getitem__(-index) # Horizontally concat the two images and sum targets to get task label img = torch.cat((img_1, img_2), dim=2) # Sum the targets to get the task label y = target_1 + target_2 # One hot encoding of the concept label on 20 digits c = torch.zeros(20) c[target_1] = 1 c[target_2 + 10] = 1 return img, c, y
[docs] class PartialMNISTAddition(MNISTAddition): """ The partial MNIST addition dataset is a modified version of the MNIST addition dataset where the concept annotation is partial. The concept associated with the second digit is not provided. """ name = "partial_mnist_addition" n_concepts = 10 concept_names = [ "0_left", "1_left", "2_left", "3_left", "4_left", "5_left", "6_left", "7_left", "8_left", "9_left", ] def __getitem__(self, index) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: x, c, y = super(PartialMNISTAddition, self).__getitem__(index) c = c[:10] return x, c, y
[docs] class MNISTEvenOdd(MNIST): """ The MNIST even-odd dataset is a modified version of the MNIST dataset where the task is to predict whether the digit is even or odd. The concept label is a one-hot encoding of the digit. Attributes: concept_names: The names of the concept labels. task_names: The names of the task labels. root: The root directory where the dataset is stored. train: Whether to load the training or test split. Default is False. transform: The transformations to apply to the images. Default is None. target_transform: The transformations to apply to the target labels. Default is None. download: Whether to download the dataset if it does not exist. Default is False. """ name = "mnist_even_odd" n_concepts = 10 n_tasks = 2 concept_names = [ "0", "1", "2", "3", "4", "5", "6", "7", "8", "9" ] task_names = ["odd", "even"] transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) input_shape = (1, 28, 28) input_dim = 28 * 28
[docs] def __init__(self, root, train, target_transform=None, download=True): super(MNISTEvenOdd, self).__init__(root, train, self.transform, target_transform, download)
def __getitem__(self, index) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: x, y = super(MNISTEvenOdd, self).__getitem__(index) # One hot encoding of the concept label on 10 digits c = torch.zeros(10) c[y] = 1 # Task label is 1 if digit is even, 0 if digit is odd t = 1 if y % 2 == 0 else 0 return x, c, t