Source code for torch_concepts.data.io
"""
Input/output utilities for data handling.
This module provides utilities for downloading, extracting, and saving/loading
data files, including support for zip/tar archives and pickle files.
"""
import os
import pickle
import tarfile
import urllib.request
import zipfile
import logging
from typing import Any, Optional
from tqdm import tqdm
logger = logging.getLogger(__name__)
[docs]
def save_pickle(obj: Any, filename: str) -> str:
    """
    Serialize an object to disk with pickle.

    The destination is resolved to an absolute path and any missing parent
    directories are created before the file is written.

    Args:
        obj: Object to be saved.
        filename: Where to save the file (relative or absolute).

    Returns:
        str: The absolute path to the saved pickle.
    """
    target = os.path.abspath(filename)
    parent, _ = os.path.split(target)
    # Create the whole directory chain up front so open() cannot fail on a
    # missing folder.
    os.makedirs(parent, exist_ok=True)
    with open(target, 'wb') as stream:
        pickle.dump(obj, stream)
    return target
[docs]
def load_pickle(filename: str) -> Any:
    """
    Read a pickled object back from disk.

    Note:
        Unpickling can execute arbitrary code; only load files that come
        from a trusted source.

    Args:
        filename: Path to the pickle file.

    Returns:
        Any: The loaded object.
    """
    with open(filename, 'rb') as stream:
        return pickle.load(stream)
[docs]
class DownloadProgressBar(tqdm):
    """
    tqdm progress bar that can be driven by a download report hook.

    Adds :meth:`update_to`, whose signature matches the ``reporthook``
    callback of ``urllib.request.urlretrieve``.
    Adapted from https://stackoverflow.com/a/53877507
    """

    def update_to(self, b=1, bsize=1, tsize=None):
        """
        Move the bar to the absolute position ``b * bsize``.

        Args:
            b: Number of blocks transferred so far (default: 1).
            bsize: Size of each block in bytes (default: 1).
            tsize: Total size, in the same units as ``b * bsize`` (bytes when
                used as a ``urlretrieve`` report hook). If not None it
                replaces the bar's current total (default: None).
        """
        if tsize is not None:
            self.total = tsize
        transferred = b * bsize
        # tqdm.update() expects an increment, while the hook reports an
        # absolute position, so pass the difference to the current count.
        self.update(transferred - self.n)
[docs]
def download_url(url: str,
                 folder: str,
                 filename: Optional[str] = None,
                 verbose: bool = True) -> str:
    r"""Downloads the content of an URL to a specific folder.

    If the target file already exists it is reused and nothing is downloaded.
    Otherwise the data is first written to ``<path>.part`` and only renamed
    to its final name once the transfer succeeds, so an interrupted download
    is never mistaken for a complete file on a later call.

    Args:
        url (string): The url.
        folder (string): The folder. Created if it does not exist.
        filename (string, optional): The filename. If :obj:`None`, inferred
            from the last path segment of the url (query string and fragment
            are ignored).
        verbose (bool, optional): If :obj:`False`, will not show progress bars.
            (default: :obj:`True`)

    Returns:
        str: Path to the downloaded (or already existing) file.

    Raises:
        ValueError: If the filename is empty, e.g. because none was given
            and the url path ends with ``/``.
    """
    if filename is None:
        # Local import keeps this block self-contained.
        from urllib.parse import urlsplit

        # Use only the path component so that a '/' inside the query string
        # or a trailing '#fragment' cannot leak into the file name.
        filename = urlsplit(url).path.rpartition('/')[2]

    if not filename:
        # An empty name would make `path` point at the folder itself.
        raise ValueError(
            f'Cannot determine a file name for {url!r}; '
            'pass a non-empty `filename`.')

    path = os.path.join(folder, filename)

    if os.path.exists(path):
        logger.info('Using existing file %s', filename)
        return path

    logger.info('Downloading %s', url)
    os.makedirs(folder, exist_ok=True)

    # Download next to the final location so os.replace() stays on the same
    # filesystem and the rename is atomic.
    partial_path = path + '.part'
    try:
        # From https://stackoverflow.com/a/53877507
        with DownloadProgressBar(unit='B',
                                 unit_scale=True,
                                 miniters=1,
                                 desc=filename,
                                 disable=not verbose) as t:
            urllib.request.urlretrieve(url,
                                       filename=partial_path,
                                       reporthook=t.update_to)
        os.replace(partial_path, path)
    except BaseException:
        # Also covers KeyboardInterrupt: drop the incomplete file, then let
        # the original error propagate unchanged.
        try:
            os.remove(partial_path)
        except OSError:
            # Best effort: the partial file may never have been created.
            pass
        raise
    return path