# Source code for torch_concepts.data.datasets.mnist_arithmetic

import os
import random
import torch
import numpy as np
import pandas as pd
import logging
from typing import List, Optional, Tuple
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm

from ..base.dataset import ConceptDataset
from ...annotations import Annotations

# Module-level logger; download/build/load progress of the dataset is reported
# through it at INFO level.
logger = logging.getLogger(__name__)

# Names of the two concept columns (the digit values on either side of the
# operator). build() places them first in the saved concept/target tensor.
CONCEPT_NAMES = ['first_digit', 'second_digit']
# Name of the regression target column (the arithmetic result); build()
# appends it after the concept columns.
TASK_NAMES = ['result']


def _import_torchvision():
    """Lazily import torchvision, raising a clear error if it is not installed."""
    try:
        import torchvision as tv
        return tv
    except ImportError as exc:
        raise ImportError(
            "MNISTArithmeticDataset requires `torchvision`. "
            "Install it with: pip install torchvision"
        ) from exc


def _generate_arithmetic_data(
    img_dir: str,
    mnist_root: str,
    train: bool,
    num_samples: int,
    img_size: int,
    seed: int,
    filename_offset: int = 0,
):
    """Generate MNIST arithmetic composite images and save them to disk.

    Each sample is rendered on an 84x28 grayscale canvas as
    ``[digit a][operator][digit b]``. The canvas is resized to
    ``(img_size, img_size)``, expanded to 3 channels, and written to
    ``img_dir/sample_<filename_offset + idx>.png``. Operators are drawn
    uniformly from ``+``, ``-``, ``x`` and ``/``. Digits labelled 0 are
    rejected and resampled, so ``/`` never divides by zero.

    All randomness comes from a private ``random.Random(seed)`` instance. The
    output is reproducible for a given ``seed``, and the caller's global
    ``random`` / ``numpy`` generator state is left untouched.

    Args:
        img_dir: Directory to save generated images (created if missing).
        mnist_root: Root of the MNIST data. It is opened with
            ``download=False``, so it must already be downloaded.
        train: Whether to draw digits from the MNIST train split.
        num_samples: Number of samples to generate.
        img_size: Output image side length (square).
        seed: Random seed.
        filename_offset: Starting index for filenames. This avoids collisions
            between calls that write into the same ``img_dir``.

    Returns:
        Tuple ``(filenames, concepts, tasks)`` of lists with one entry per
        sample:
        - the PNG file name, relative to ``img_dir``;
        - ``[a, b]`` as floats;
        - ``[result]`` as a float (``a / b`` is true division).
    """
    # A dedicated RNG yields exactly the draw sequence that
    # random.seed(seed) + module-level random calls did, without mutating
    # process-wide state. numpy's RNG is not used here, so it is not seeded.
    rng = random.Random(seed)

    tv = _import_torchvision()
    mnist = tv.datasets.MNIST(root=mnist_root, train=train, download=False, transform=None)
    last_index = len(mnist) - 1

    # The composite canvas is 84x28: three 28x28 cells for digit, operator and
    # digit. Resizing it to (img_size, img_size) does not preserve aspect ratio.
    resize_transform = tv.transforms.Compose([
        tv.transforms.Resize((img_size, img_size)),
        tv.transforms.Grayscale(num_output_channels=3),
    ])

    # Font size 20-24 fits a 28x28 pixel cell.
    try:
        font = ImageFont.truetype("arial.ttf", 22)
    except OSError:
        # NOTE(review): depending on the Pillow version, the default font may
        # be a small bitmap font. The operator glyph could then render smaller
        # or off-center compared to arial; confirm on the target platform.
        font = ImageFont.load_default()

    os.makedirs(img_dir, exist_ok=True)

    operators = ('+', '-', 'x', '/')
    # All operators are drawn before any digit sampling. Keep this order: it
    # determines which samples a given seed produces.
    operator_list = [rng.choice(operators) for _ in range(num_samples)]

    filenames = []
    concepts_list = []
    tasks_list = []

    split_name = 'train' if train else 'test'
    for idx in tqdm(range(num_samples), desc=f"Generating MNIST arithmetic ({split_name})"):
        # Sample two digits, rejecting label 0. The interleaved resampling
        # order is part of the reproducible draw sequence.
        i1 = rng.randint(0, last_index)
        i2 = rng.randint(0, last_index)
        img1, a = mnist[i1]
        img2, b = mnist[i2]

        while a == 0 or b == 0:
            if a == 0:
                i1 = rng.randint(0, last_index)
                img1, a = mnist[i1]
            if b == 0:
                i2 = rng.randint(0, last_index)
                img2, b = mnist[i2]

        op = operator_list[idx]

        if op == '+':
            result = a + b
        elif op == '-':
            result = a - b
        elif op == 'x':
            result = a * b
        elif op == '/':
            result = a / b
        else:
            raise ValueError(f"Unknown operator: {op}")

        # --- Composite image: [digit a][operator][digit b] ---
        canvas = Image.new("L", (84, 28), color=0)
        canvas.paste(img1, (0, 0))

        op_canvas = Image.new("L", (28, 28), color=0)
        draw = ImageDraw.Draw(op_canvas)
        if op == '-':
            # Draw the minus as a filled bar (x 7-21, y 13-14) rather than a
            # glyph, so it cannot be mistaken for a dot.
            draw.rectangle([7, 13, 21, 14], fill=255)
        else:
            # anchor="mm" centers the glyph on the cell's midpoint (14, 14).
            draw.text((14, 14), op, fill=255, font=font, anchor="mm")

        canvas.paste(op_canvas, (28, 0))
        canvas.paste(img2, (56, 0))

        final_img = resize_transform(canvas)

        fname = f"sample_{filename_offset + idx}.png"
        final_img.save(os.path.join(img_dir, fname))

        filenames.append(fname)
        concepts_list.append([float(a), float(b)])
        tasks_list.append([float(result)])

    return filenames, concepts_list, tasks_list


class MNISTArithmeticDataset(ConceptDataset):
    """MNIST Arithmetic dataset for regression with concept annotations.

    Composite images of two MNIST digits with an arithmetic operator between
    them. The concepts are the two digit values (treated as continuous). The
    regression task is the arithmetic result.

    Images in the training/validation splits are composed from MNIST train
    digits, while test images are composed from MNIST test digits. This
    ensures no digit-level leakage between train and test.

    Generated images are stored under ``<root_dir>/images/<config>/``.
    ``<config>`` encodes ``num_train_samples``, ``num_test_samples``, ``seed``
    and ``img_size``. Datasets built with different settings in the same
    ``root`` therefore neither reuse nor overwrite each other's images.

    Parameters
    ----------
    root : str, optional
        Root directory to store/load the dataset.
        Default: ``'./data/mnist_arithmetic'``.
    num_train_samples : int, optional
        Number of composite samples from the MNIST train split (used for
        train + validation). Default: 10000.
    num_test_samples : int, optional
        Number of composite samples from the MNIST test split. Default: 2000.
    val_size : float, optional
        Fraction of the train pool to use as validation. Default: 0.1.
    img_size : int, optional
        Output image size (square). Default: 224.
    seed : int, optional
        Random seed for reproducible generation. Default: 42.
    concept_subset : list, optional
        Subset of concept names to keep. Forwarded to ``ConceptDataset`` as
        ``concept_names_subset``. Default: ``None``.
    label_descriptions : dict, optional
        Mapping from concept names to descriptions, stored on the instance as
        ``label_descriptions``.
    """

    def __init__(
        self,
        root: Optional[str] = None,
        num_train_samples: int = 10000,
        num_test_samples: int = 2000,
        val_size: float = 0.1,
        img_size: int = 224,
        seed: int = 42,
        concept_subset: Optional[list] = None,
        label_descriptions: Optional[dict] = None,
    ):
        self.num_train_samples = num_train_samples
        self.num_test_samples = num_test_samples
        self.val_size = val_size
        # NOTE(review): stored but not forwarded to ConceptDataset here;
        # confirm whether the base class or callers read it.
        self.label_descriptions = label_descriptions
        self.img_size = img_size
        self.seed = seed
        # Kept as a public attribute for compatibility; generation itself uses
        # the same operator set internally.
        self.operators = ('+', '-', 'x', '/')

        if root is None:
            root = os.path.join(os.getcwd(), 'data', 'mnist_arithmetic')
        self.root = root

        # Must run before super().__init__, which needs the loaded data.
        filenames, concepts, annotations, graph = self.load()

        super().__init__(
            input_data=filenames,
            concepts=concepts,
            annotations=annotations,
            graph=graph,
            concept_names_subset=concept_subset,
            name="MNISTArithmeticDataset",
        )

    @property
    def _config_tag(self) -> str:
        """Identifier of every setting that affects the generated images."""
        return (
            f"Ntrain_{self.num_train_samples}_Ntest_{self.num_test_samples}"
            f"_seed_{self.seed}_size_{self.img_size}"
        )

    @property
    def _images_dir(self) -> str:
        """Directory holding the PNGs for this configuration."""
        return os.path.join(self.root_dir, "images", self._config_tag)

    @property
    def raw_filenames(self) -> List[str]:
        return [
            "MNIST/raw/t10k-images-idx3-ubyte",  # MNIST test images
            "MNIST/raw/t10k-labels-idx1-ubyte",  # MNIST test labels
            "MNIST/raw/train-images-idx3-ubyte",  # MNIST train images
            "MNIST/raw/train-labels-idx1-ubyte",  # MNIST train labels
        ]

    @property
    def processed_filenames(self) -> List[str]:
        # img_size is part of the key, so changing it triggers a rebuild
        # instead of silently reusing images of the wrong size.
        # NOTE(review): annotations.pt and split_mapping.h5 keep fixed names,
        # because other code may look them up by name. split_mapping.h5 is
        # therefore shared across configurations, and val_size is not part of
        # any key. Confirm before keying it.
        return [
            f"filenames_{self._config_tag}.txt",
            f"concepts_{self._config_tag}.pt",
            "annotations.pt",
            "split_mapping.h5",
        ]

    def download(self):
        """Download MNIST (both splits) into ``root`` and delete the .gz archives."""
        tv = _import_torchvision()
        tv.datasets.MNIST(root=self.root, train=True, download=True)
        tv.datasets.MNIST(root=self.root, train=False, download=True)

        # Remove zipped raw files to save space.
        for fname in self.raw_filenames:
            path = os.path.join(self.root, fname + ".gz")
            if os.path.exists(path):
                os.remove(path)

        raw_mnist_path = os.path.join(self.root, "MNIST/raw")
        logger.info(f"MNIST files downloaded to {raw_mnist_path}.")

    def maybe_download(self):
        """Download and extract the dataset if needed."""
        super().maybe_download()

    def build(self):
        """Generate composite images from both MNIST splits and save the metadata.

        Writes the processed files listed in ``processed_filenames``:
        - the image filenames (one per line);
        - a float32 tensor with one row ``[first_digit, second_digit, result]``
          per sample;
        - the ``Annotations``;
        - a per-sample split label (``'train'`` / ``'val'`` / ``'test'``).
        Train/val samples come first, then test samples.
        """
        self.maybe_download()
        logger.info(f"Generating MNIST arithmetic dataset "
                    f"(train={self.num_train_samples}, test={self.num_test_samples}, seed={self.seed})")

        img_dir = self._images_dir

        # Generate from the MNIST train split.
        train_filenames, train_concepts, train_tasks = _generate_arithmetic_data(
            img_dir=img_dir,
            mnist_root=self.root,
            train=True,
            num_samples=self.num_train_samples,
            img_size=self.img_size,
            seed=self.seed,
            filename_offset=0,
        )

        # Generate from the MNIST test split. A different seed avoids an
        # operator sequence identical to the train one.
        test_filenames, test_concepts, test_tasks = _generate_arithmetic_data(
            img_dir=img_dir,
            mnist_root=self.root,
            train=False,
            num_samples=self.num_test_samples,
            img_size=self.img_size,
            seed=self.seed + 1,
            filename_offset=self.num_train_samples,
        )

        all_filenames = train_filenames + test_filenames
        all_concepts = train_concepts + test_concepts
        all_tasks = train_tasks + test_tasks

        # Each row: concept values followed by the task value.
        cy = torch.tensor(
            [c + t for c, t in zip(all_concepts, all_tasks)],
            dtype=torch.float32,
        )

        cy_names = CONCEPT_NAMES + TASK_NAMES
        cardinalities = tuple([1] * len(cy_names))
        annotations = Annotations(
            labels=cy_names,
            cardinalities=cardinalities,
            types=['continuous'] * len(cy_names),
        )

        # Randomly split the MNIST-train pool into train/val. A private
        # RandomState gives the same permutation as np.random.seed(seed)
        # followed by np.random.permutation, without touching global state.
        n_val = int(self.num_train_samples * self.val_size)
        perm = np.random.RandomState(self.seed).permutation(self.num_train_samples)
        val_indices = set(perm[:n_val].tolist())
        split_labels = [
            'val' if i in val_indices else 'train'
            for i in range(self.num_train_samples)
        ]
        split_labels.extend(['test'] * self.num_test_samples)

        os.makedirs(self.root_dir, exist_ok=True)
        with open(self.processed_paths[0], 'w') as f:
            f.write('\n'.join(all_filenames))
        torch.save(cy, self.processed_paths[1])
        torch.save(annotations, self.processed_paths[2])
        pd.Series(split_labels).to_hdf(self.processed_paths[3], key="split_mapping", mode="w")

        logger.info(f"MNIST arithmetic dataset saved to {self.root_dir} "
                    f"(train={self.num_train_samples - n_val}, val={n_val}, test={self.num_test_samples})")

    def load_raw(self):
        """Build the dataset if needed, then load the processed files for all samples.

        Returns:
            ``(filenames, concepts, annotations, graph)``. ``graph`` is
            always ``None`` for this dataset.
        """
        self.maybe_build()
        logger.info(f"Loading MNIST arithmetic dataset from {self.root_dir}")
        with open(self.processed_paths[0], 'r') as f:
            # Skip empty lines, so an empty dataset yields [] rather than [''].
            filenames = [line for line in f.read().splitlines() if line]
        # weights_only=False is needed to unpickle the Annotations object.
        # These files are produced by build(); do not point this at
        # untrusted files.
        concepts = torch.load(self.processed_paths[1], weights_only=False)
        annotations = torch.load(self.processed_paths[2], weights_only=False)
        graph = None
        return filenames, concepts, annotations, graph

    def load(self):
        return self.load_raw()

    def __getitem__(self, item):
        """Return ``{'inputs': {'x': x}, 'concepts': {'c': c}}`` for sample ``item``.

        Without precomputed embeddings, ``x`` is the image as a float tensor
        with shape ``(3, H, W)``, scaled to ``[0, 1]``.
        """
        if self.embs_precomputed:
            x = self.input_data[item]
        else:
            img_path = os.path.join(self._images_dir, self.input_data[item])
            with Image.open(img_path) as img:
                pixels = np.array(img.convert('RGB'))
            x = torch.from_numpy(pixels).permute(2, 0, 1).float() / 255.0

        c = self.concepts[item]

        return {
            'inputs': {'x': x},
            'concepts': {'c': c},
        }

    @property
    def n_samples(self) -> int:
        return len(self.input_data)

    @property
    def n_features(self) -> tuple:
        # Shape of one input; note that this loads sample 0.
        return tuple(self[0]['inputs']['x'].shape)

    @property
    def shape(self) -> tuple:
        return (self.n_samples, *self.n_features)