PytorchDatasetCaching

This repository provides a wrapper for any PyTorch dataset that caches augmented samples. This speeds up training when augmentations are compute-intensive, for example with 3D images or extensive augmentation pipelines.

What it does: for each sample, it applies a configurable number of transformations and saves the results to a cache directory. Once that number has been reached, it randomly loads one of the saved samples without applying any further transformations. This caps the number of distinct augmented versions of each sample, but if you set a large enough number of transformations, this should not be a problem.
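The behaviour described above can be sketched roughly as follows. This is a minimal illustration, not the code in cache_dataset.py: the class name `CachingWrapperSketch`, the file naming scheme, and the use of `pickle` are all assumptions made for the example. It only assumes the wrapped dataset is map-style (`__len__` and `__getitem__`) and applies its augmentations inside `__getitem__`.

```python
import os
import pickle
import random


class CachingWrapperSketch:
    """Hypothetical illustration of the caching idea; not the repository's CacheDataset."""

    def __init__(self, dataset, augmentations_per_sample, cache_dir):
        self.dataset = dataset
        self.augmentations_per_sample = augmentations_per_sample
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

    def __len__(self):
        return len(self.dataset)

    def _path(self, index, slot):
        # One file per (sample index, augmentation slot); the naming is an assumption.
        return os.path.join(self.cache_dir, f"{index}_{slot}.pkl")

    def __getitem__(self, index):
        missing = [
            slot
            for slot in range(self.augmentations_per_sample)
            if not os.path.exists(self._path(index, slot))
        ]
        if missing:
            # Cache not full yet: run the wrapped dataset (and its augmentations)
            # and store the result in the first free slot.
            sample = self.dataset[index]
            target = self._path(index, missing[0])
            tmp = f"{target}.{os.getpid()}.tmp"
            with open(tmp, "wb") as f:
                pickle.dump(sample, f)
            os.replace(tmp, target)  # atomic rename, so readers never see a partial file
            return sample
        # Cache full: load a random stored augmentation, no new transformation is applied.
        slot = random.randrange(self.augmentations_per_sample)
        with open(self._path(index, slot), "rb") as f:
            return pickle.load(f)
```

In this sketch the expensive augmentation runs at most `augmentations_per_sample` times per sample; every later access is a disk read. Writing to a temporary file and renaming it is one way to stay safe when several DataLoader workers touch the same cache directory.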

How to use:

Clone this repository into your project

git clone https://github.com/hhroberthdaniel/PytorchDatasetCaching.git

or just copy the code from cache_dataset.py

import torch
from cache_dataset import CacheDataset  # adjust the import path if you cloned the repository into a subfolder

trainset = None  # set this to your own PyTorch Dataset
cached_trainset = CacheDataset(trainset, augmentations_per_sample=AUGMENTATIONS_PER_SAMPLE, cache_dir="./tmp")
trainloader = torch.utils.data.DataLoader(cached_trainset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)