Source code for torch_concepts.data.datasets.traffic

"""
Synthetic traffic intersection dataset.
"""

import copy
import hashlib
import json
import logging
import numpy as np
import os
import pickle
import torch

from torch.utils.data import Dataset

from .traffic_construction.cars import AVAILABLE_CAR_COLORS
from .traffic_construction.generate_data import construct_samples, fix_seeds

# Canonical, ordered names of the binary concepts attached to every sample.
# The position of a name in this list is its concept index: the list is
# searched with `.index(...)` to translate concept names into integer indices,
# and it lines up entry-for-entry with the 12-element vector assembled in
# `TrafficLights._from_meta_to_concepts`. Reordering, inserting, or removing
# entries therefore changes the meaning of every concept index.
STANDARD_CONCEPT_NAMES = [
    'green x-axis',
    'green y-axis',
    'ambulance seen',
    'car in intersection',
    'other cars visible',
    'selected car in north lanes',
    'selected car in east lanes',
    'selected car in south lanes',
    'selected car in west lanes',
    'green light on selected lane',
    'car in intersection perpendicular to selected car',
    'ambulance approaching perpendicular to selected car',
]

class TrafficLights(Dataset):
    """Synthetic traffic-intersection dataset with concept annotations.

    On construction the dataset is generated (or re-used from an on-disk
    cache) under ``<root_dir>/<sha256 of the generation config>/``. Each
    sample is stored as ``records/sample_<idx>.npz`` holding an ``img`` array
    and a ``metadata`` dictionary. The train/val/test membership of every
    sample is stored in ``<split>_indices.npy`` files next to the records.

    Indexing the dataset returns the 5-tuple
    ``(img, y, c, c_names, task_names)`` where ``img`` is a channels-first
    float tensor, ``y`` encodes whether the sample's action is
    ``'continue'``, ``c`` is the concept vector, ``c_names`` the matching
    concept names, and ``task_names`` is ``['continue']``.

    Attributes set by ``__init__``:
        root_dir, real_data_dir, records_dir: cache locations.
        split, split_array_map: chosen split and its sample indices.
        concept_names, task_names, selected_concepts: label metadata.
        class_dtype, concept_transform, transform: per-sample conversions.
        input_shape: shape of the first image of the split, or ``None`` when
            the split is empty.
    """

    def __init__(
        self,
        root_dir="./data_cache/",
        split='train',
        concept_transform=None,
        class_dtype=float,
        img_transform=None,
        regenerate=False,
        sym_link=None,
        use_absolute_path=False,
        # Potential config values
        num_threads=4,
        n_samples=1000,
        seed=42,
        position_para_noise=50,
        position_perp_noise=20,
        error_probability=0.1,
        p_ambulance=0.2,
        min_num_cars=0,
        max_num_cars=7,
        resize_final_image=0.15,
        car_colors=AVAILABLE_CAR_COLORS,
        possible_starting_directions=None,
        thickness=100,
        light_scale=1.5,
        use_lights_sprites=False,
        train_ratio=0.6,
        val_ratio=0.2,
        test_ratio=0.2,
        test_config_override_values=None,
        val_config_override_values=None,
        selected_concepts=None,
    ):
        """Generate (or load from cache) the dataset and select a split.

        Args:
            root_dir: Directory under which generated datasets are cached.
                Created if it does not exist.
            split: One of ``'train'``, ``'val'`` or ``'test'``.
            concept_transform: Optional callable applied to the concept
                tensor and, in ``__getitem__``, to the concept-name list.
            class_dtype: Callable used to convert the boolean label; when it
                is ``float`` the label is returned as a 0-d float tensor.
            img_transform: Optional callable applied to every image tensor.
            regenerate: Force re-generation even if a completed cache exists.
            sym_link: Optional path of a symbolic link to create, pointing at
                the generated dataset directory.
            use_absolute_path: Make ``sym_link`` point to an absolute path.
            num_threads: Forwarded to ``construct_samples``.
            n_samples: Total number of samples across all splits.
            seed: Seed handed to ``fix_seeds`` before generation.
            possible_starting_directions: Directions forwarded to the
                generator; ``None`` means all of north/east/south/west.
            train_ratio, val_ratio: Fractions of ``n_samples`` assigned to
                the training and validation splits.
            test_ratio: Recorded in the config; the test split itself is
                whatever remains after the training and validation splits.
            test_config_override_values, val_config_override_values:
                Optional dicts merged over the training config to generate
                the test / validation samples.
            selected_concepts: Optional subset of concepts to expose, given
                as indices and/or names from ``STANDARD_CONCEPT_NAMES``.
            The remaining keyword arguments are stored verbatim in the
            generation config and forwarded to ``construct_samples``.

        Raises:
            ValueError: If ``split`` is not a known split, or if
                ``selected_concepts`` contains an unknown concept name.
        """
        super().__init__()

        # Validate before doing any (potentially very expensive) generation.
        # An explicit raise is used rather than `assert` so the check is not
        # stripped when Python runs with -O.
        if split not in ('train', 'test', 'val'):
            raise ValueError(
                f"Unknown split {split!r}; expected one of "
                "'train', 'val' or 'test'."
            )

        # A fresh list per instance: a list literal as the default argument
        # would be shared (and could be mutated) across all instances.
        if possible_starting_directions is None:
            possible_starting_directions = [
                'north',
                'east',
                'south',
                'west',
            ]

        self.n_samples = n_samples
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self.test_ratio = test_ratio
        self.seed = seed
        self.sym_link = sym_link

        train_config = dict(
            n_samples=n_samples,
            position_para_noise=position_para_noise,
            position_perp_noise=position_perp_noise,
            error_probability=error_probability,
            p_ambulance=p_ambulance,
            min_num_cars=min_num_cars,
            max_num_cars=max_num_cars,
            resize_final_image=resize_final_image,
            car_colors=car_colors,
            possible_starting_directions=possible_starting_directions,
            thickness=thickness,
            light_scale=light_scale,
            use_lights_sprites=use_lights_sprites,
            seed=self.seed,
            train_ratio=train_ratio,
            val_ratio=val_ratio,
            test_ratio=test_ratio,
        )

        # Standardize concept names into indices in case names were provided.
        # `list(...)` copies the caller's sequence and also accepts tuples.
        if selected_concepts is not None:
            selected_concepts = list(selected_concepts)
            for idx, concept_name in enumerate(selected_concepts):
                if not isinstance(concept_name, str):
                    continue
                try:
                    selected_concepts[idx] = STANDARD_CONCEPT_NAMES.index(
                        concept_name
                    )
                except ValueError:
                    raise ValueError(
                        f"Unknown concept name {concept_name!r}; valid "
                        f"names are {STANDARD_CONCEPT_NAMES}."
                    ) from None
        self.selected_concepts = selected_concepts

        # Generate a hash for the config so that we can determine if this
        # dataset has been previously generated
        self.root_dir = root_dir
        if not os.path.exists(self.root_dir):
            logging.warning(
                f"Root directory '{self.root_dir}' TrafficLights does not "
                f"exist. We will generate it and dump the new dataset in "
                f"there."
            )
            # exist_ok guards against another process creating it meanwhile.
            os.makedirs(self.root_dir, exist_ok=True)

        hash_key = copy.deepcopy(train_config)
        hash_key['test_config_override_values'] = test_config_override_values
        hash_key['val_config_override_values'] = val_config_override_values
        setup_hash = hashlib.sha256(
            json.dumps(hash_key, sort_keys=True).encode()
        ).hexdigest()
        self.real_data_dir = os.path.join(self.root_dir, str(setup_hash))
        os.makedirs(self.real_data_dir, exist_ok=True)

        # Generate the dataset, if needed
        self.records_dir = os.path.join(
            self.real_data_dir,
            'records/',
        )
        os.makedirs(self.records_dir, exist_ok=True)
        self._construct_dataset(
            train_config,
            num_threads=num_threads,
            test_config_override_values=test_config_override_values,
            val_config_override_values=val_config_override_values,
            use_absolute_path=use_absolute_path,
            regenerate=regenerate,
        )

        # At this point, we can assume that the dataset has been fully
        # generated!
        self.split = split
        # Copy so that callers mutating the names returned by __getitem__
        # cannot corrupt the module-level constant.
        self.concept_names = list(STANDARD_CONCEPT_NAMES)
        if self.selected_concepts is not None:
            self.concept_names = list(
                np.array(self.concept_names)[self.selected_concepts]
            )
        self.task_names = ['continue']
        self.class_dtype = class_dtype
        self.concept_transform = concept_transform
        self.split_array_map = np.load(
            os.path.join(self.real_data_dir, f'{split}_indices.npy')
        )
        self.transform = img_transform

        # Probe the first sample for the image shape. An empty split (e.g.
        # a ratio of zero) has no sample to probe, so report None rather
        # than failing with an IndexError.
        if len(self) > 0:
            img, _, _, _, _ = self[0]
            self.input_shape = tuple(img.shape)
        else:
            self.input_shape = None

    @staticmethod
    def _compute_split_indices(n_samples, train_ratio, val_ratio):
        """Return contiguous, disjoint (train, val, test) index arrays.

        Sizes are rounded up with ``ceil`` and then clamped so that no index
        ever reaches ``n_samples``; the test split is the remainder.
        """
        train_size = min(
            n_samples,
            int(np.ceil(n_samples * train_ratio)),
        )
        val_size = min(
            n_samples - train_size,
            int(np.ceil(n_samples * val_ratio)),
        )
        train_indices = np.arange(0, train_size)
        val_indices = np.arange(train_size, train_size + val_size)
        test_indices = np.arange(train_size + val_size, n_samples)
        return train_indices, val_indices, test_indices

    def _construct_dataset(
        self,
        config,
        num_threads=4,
        test_config_override_values=None,
        val_config_override_values=None,
        use_absolute_path=False,
        regenerate=False,
    ):
        """Generate all splits on disk unless a completed cache exists.

        Writes the split index files, one pickled config per split, the
        sample records (through ``construct_samples``), a ``completed.txt``
        marker and, if ``self.sym_link`` is set, a symbolic link to the
        dataset directory.

        Args:
            config: Generation config used for the training split.
            num_threads: Forwarded to ``construct_samples``.
            test_config_override_values: Optional dict merged over ``config``
                for the test split.
            val_config_override_values: Optional dict merged over ``config``
                for the validation split.
            use_absolute_path: Make the symbolic link target absolute.
            regenerate: Ignore an existing completed cache and rebuild.
        """
        completed_marker = os.path.join(self.real_data_dir, 'completed.txt')
        if os.path.exists(completed_marker):
            if not regenerate:
                logging.info(
                    'We found a dataset previously generated with the same '
                    'config that has been cached.'
                )
                logging.info(
                    '\tIf you wish to re-generate it, please use '
                    'regenerate=True.'
                )
                # NOTE(review): the symbolic link below is only created on a
                # fresh generation, never on a cache hit -- confirm that this
                # is intended.
                return
            # Drop the stale marker first: if regeneration is interrupted,
            # the half-rewritten cache must not be mistaken for a complete
            # dataset on the next run.
            os.remove(completed_marker)

        ## Fix Random Seeds
        fix_seeds(config.get('seed', self.seed))

        ## Split the data
        # Because the data is randomly generated (i.e., already shuffled), we
        # can just split the data into three by taking contiguous but
        # disjoint areas of it
        n_samples = config.get('n_samples', self.n_samples)
        train_indices, val_indices, test_indices = (
            self._compute_split_indices(
                n_samples,
                self.train_ratio,
                self.val_ratio,
            )
        )

        # Serialize them for future reference
        np.save(
            os.path.join(self.real_data_dir, 'train_indices.npy'),
            train_indices,
        )
        np.save(
            os.path.join(self.real_data_dir, 'val_indices.npy'),
            val_indices,
        )
        np.save(
            os.path.join(self.real_data_dir, 'test_indices.npy'),
            test_indices,
        )

        with open(
            os.path.join(self.real_data_dir, 'train_config.pkl'),
            'wb',
        ) as f:
            pickle.dump(config, f)
        construct_samples(
            config=config,
            records_dir=self.records_dir,
            indices=train_indices,
            dataset_name="training",
            num_threads=num_threads,
        )

        val_config = copy.deepcopy(config)
        val_config.update(val_config_override_values or {})
        with open(
            os.path.join(self.real_data_dir, 'val_config.pkl'),
            'wb',
        ) as f:
            pickle.dump(val_config, f)
        construct_samples(
            config=val_config,
            records_dir=self.records_dir,
            indices=val_indices,
            dataset_name="validation",
            num_threads=num_threads,
        )

        test_config = copy.deepcopy(config)
        test_config.update(test_config_override_values or {})
        with open(
            os.path.join(self.real_data_dir, 'test_config.pkl'),
            'wb',
        ) as f:
            pickle.dump(test_config, f)
        construct_samples(
            config=test_config,
            records_dir=self.records_dir,
            indices=test_indices,
            dataset_name="test",
            num_threads=num_threads,
        )

        # And mark the dataset as complete
        with open(completed_marker, "w") as f:
            f.write("1")

        if self.sym_link is not None:
            # Make the symbolic link (overwriting a previous one if it
            # already exists)
            if os.path.islink(self.sym_link):
                os.remove(self.sym_link)
            if use_absolute_path:
                os.symlink(os.path.abspath(self.real_data_dir), self.sym_link)
            else:
                os.symlink(self.real_data_dir, self.sym_link)

    def _from_meta_to_concepts(self, sample_meta):
        """Build the concept tensor for one sample from its metadata.

        Concepts will be:
            [0] Light color x-axis (0 if red, 1 if green)
            [1] Light color y-axis (0 if red, 1 if green)
            [2] Ambulance (1 if there is an ambulance in sight, 0 otherwise)
            [3] Car in intersection (1 if there is a car in the
                intersection, 0 otherwise)
            [4] Other cars (1 if there are other cars visible anywhere, 0
                otherwise)
            [5] Selected car in north lane
            [6] Selected car in east lane
            [7] Selected car in south lane
            [8] Selected car in west lane
            [9] Green light on selected lane
            [10] Car perpendicular in intersection (1 if there is a car in
                 the intersection in the direction perpendicular to this
                 car, 0 otherwise)
            [11] Ambulance Perpendicular (1 if the ambulance is in the
                 direction perpendicular to the car, 0 otherwise)

        Args:
            sample_meta: Metadata dict of a sample. The keys read here are
                ``'green'``, ``'selected_lane'`` (with ``'dir'`` and
                ``'idx'``), ``'other_cars'`` (items with ``'ambulance'`` and
                ``'in_intersection'``), ``'perp_intersection_occupied'`` and
                ``'perp_incoming_ambulance'``.

        Returns:
            A float tensor restricted to ``self.selected_concepts`` (when
            set) and passed through ``self.concept_transform`` (when set).
        """
        is_green = sample_meta['green']
        lane = sample_meta['selected_lane']
        other_cars = sample_meta['other_cars']
        # NOTE(review): lane indices 7/1/3/5 are taken to mean
        # north/east/south/west per the concept list above -- the encoding
        # itself is defined by the sample generator; confirm there.
        c = np.array([
            float(is_green and (lane['dir'] in ['east', 'west'])),  # [0]
            float(is_green and (lane['dir'] in ['south', 'north'])),  # [1]
            float(np.any([x['ambulance'] for x in other_cars])),  # [2]
            float(np.any([x['in_intersection'] for x in other_cars])),  # [3]
            float(len(other_cars) > 0),  # [4]
            float(lane['idx'] == 7),  # [5]
            float(lane['idx'] == 1),  # [6]
            float(lane['idx'] == 3),  # [7]
            float(lane['idx'] == 5),  # [8]
            float(is_green),  # [9]
            float(sample_meta['perp_intersection_occupied']),  # [10]
            float(sample_meta['perp_incoming_ambulance']),  # [11]
        ])
        if self.selected_concepts is not None:
            c = c[self.selected_concepts]

        if self.concept_transform is not None:
            return self.concept_transform(torch.FloatTensor(c))
        return torch.FloatTensor(c)

    def _from_meta_to_label(self, sample_meta):
        """Return the task label: whether the sample's action is 'continue'.

        The boolean is converted with ``self.class_dtype``; when that is
        ``float`` the label is returned as a 0-d float tensor.
        """
        y = self.class_dtype(sample_meta['action'] == 'continue')
        if self.class_dtype == float:
            y = torch.FloatTensor([y]).squeeze(-1)
        return y

    def sample_array(self, real_idx):
        """Load one stored sample by its on-disk index.

        Args:
            real_idx: Index used in the record filename
                ``sample_<real_idx>.npz`` (not the position within a split).

        Returns:
            ``(img, metadata)``: the image as a channels-first float tensor
            (after ``self.transform`` if one is set) and the metadata dict.
        """
        sample_filename = os.path.join(
            self.records_dir,
            f'sample_{real_idx}.npz'
        )
        # The context manager closes the underlying archive immediately
        # instead of leaving one open file handle per loaded sample.
        # SECURITY: allow_pickle=True executes pickled code stored in the
        # file; only load records generated by this dataset, never
        # untrusted .npz files.
        with np.load(sample_filename, allow_pickle=True) as loaded_data:
            img = loaded_data['img']
            metadata = loaded_data['metadata'].item()
        img = torch.FloatTensor(
            # Transpose the image so that channels are first
            # Note: the image is already normalized so its values are within
            # [0, 1]
            np.transpose(img, [2, 0, 1])
        )
        if self.transform is not None:
            img = self.transform(img)
        return img, metadata

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.split_array_map)

    def __getitem__(self, idx):
        """Return ``(img, y, c, c_names, task_names)`` for split position idx."""
        # Translate the position within the split into the on-disk index.
        real_idx = self.split_array_map[idx]
        img, sample_meta = self.sample_array(real_idx)
        y = self._from_meta_to_label(sample_meta)
        c = self._from_meta_to_concepts(sample_meta)
        c_names = self.concept_names
        if self.concept_transform is not None:
            # NOTE(review): the same transform is applied to the concept
            # tensor and to this list of names -- presumably so both stay
            # aligned; confirm the transform supports both input types.
            c_names = self.concept_transform(c_names)
        return (
            img,
            y,
            c,
            c_names,
            self.task_names,
        )