Implementation of Real NVP in PyTorch

hsaghir/real-nvp

Real NVP in PyTorch

Implementation of Real NVP in PyTorch. Based on the paper:

Density estimation using Real NVP
Laurent Dinh, Jascha Sohl-Dickstein, Samy Bengio
arXiv:1605.08803

The training script and hyperparameters are designed to match the CIFAR-10 experiments described in Section 4.1 of the paper.
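Real NVP builds its invertible map from stacked affine coupling layers. With a binary mask b, a layer computes y = b ⊙ x + (1 - b) ⊙ (x ⊙ exp(s(b ⊙ x)) + t(b ⊙ x)). Because the masked half passes through unchanged, the Jacobian is triangular: its log-determinant is just the sum of s over the transformed entries, and the inverse never has to invert s or t. The sketch below illustrates one such layer with a checkerboard mask. It is a minimal NumPy illustration, not this repository's PyTorch code: the names checkerboard_mask and ToyAffineCoupling are hypothetical, and s and t are small random linear maps (with tanh bounding s) standing in for the deep convolutional residual networks used in the paper.

```python
import numpy as np


def checkerboard_mask(height, width, reverse=False):
    """Return a float mask with 1 where (row + col) is even, or its complement if reverse."""
    rows, cols = np.indices((height, width))
    mask = ((rows + cols) % 2 == 0).astype(np.float64)
    return 1.0 - mask if reverse else mask


class ToyAffineCoupling:
    """One masked affine coupling layer (hypothetical toy, not the repo's class).

    Forward: y = b*x + (1 - b) * (x * exp(s(b*x)) + t(b*x))
    s and t are fixed random linear maps (s squashed by tanh), standing in
    for the deep ResNets that Real NVP actually uses.
    """

    def __init__(self, mask, seed=0):
        rng = np.random.default_rng(seed)
        n = mask.size
        self.mask = mask
        self.w_s = rng.normal(scale=0.1, size=(n, n))
        self.w_t = rng.normal(scale=0.1, size=(n, n))

    def _scale_shift(self, x_masked):
        # s and t only ever see the masked (unchanged) half of the input.
        flat = x_masked.reshape(-1)
        s = np.tanh(self.w_s @ flat).reshape(self.mask.shape)
        t = (self.w_t @ flat).reshape(self.mask.shape)
        return s, t

    def forward(self, x):
        b = self.mask
        s, t = self._scale_shift(b * x)
        y = b * x + (1 - b) * (x * np.exp(s) + t)
        # Triangular Jacobian: log|det J| is the sum of s over the transformed entries.
        log_det = np.sum((1 - b) * s)
        return y, log_det

    def inverse(self, y):
        b = self.mask
        # b*y equals b*x, so s and t can be recomputed exactly from the output.
        s, t = self._scale_shift(b * y)
        return b * y + (1 - b) * (y - t) * np.exp(-s)


mask = checkerboard_mask(4, 4)
layer = ToyAffineCoupling(mask)
x = np.random.default_rng(1).normal(size=(4, 4))
y, log_det = layer.forward(x)
x_rec = layer.inverse(y)
print(np.allclose(x, x_rec))  # True
```

A single layer leaves half the entries untouched, so stacked layers alternate between the mask and its complement (reverse=True here) so that every entry is eventually transformed.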

Usage

Environment Setup

  1. Make sure you have Anaconda or Miniconda installed.
  2. Clone the repo into a directory named rnvp: git clone https://github.com/chrischute/real-nvp.git rnvp
  3. Go into the cloned repo: cd rnvp
  4. Create the environment: conda env create -f environment.yml
  5. Activate the environment: source activate rnvp (on newer conda versions: conda activate rnvp)

Train

  1. Make sure you've created and activated the conda environment as described above.
  2. Run python train.py -h to see options.
  3. Run python train.py [FLAGS] to train. For example, python train.py trains with the default configuration, and python train.py --gpu_ids=[0,1] --batch_size=128 trains on 2 GPUs instead of the default single GPU.
  4. At the end of each epoch, samples from the model will be saved to samples/epoch_N.png, where N is the epoch number.
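The --gpu_ids value in step 3 is written as a list literal. One way a training script could accept that format is with argparse and ast.literal_eval, as in the hypothetical sketch below. Only the flag names --gpu_ids and --batch_size come from this README; parse_int_list, build_parser, and the default of [0] (GPU 0 for the single-GPU default) are assumptions, and train.py may parse its flags differently. In shells such as zsh, unquoted brackets are treated as a glob pattern, so quoting the value (--gpu_ids='[0,1]') is a safe habit.

```python
import argparse
import ast


def parse_int_list(text):
    """Parse a list literal such as '[0,1]' (or a bare int such as '0') into a list of ints."""
    try:
        value = ast.literal_eval(text)
    except (ValueError, SyntaxError) as exc:
        raise argparse.ArgumentTypeError(f"could not parse {text!r}") from exc
    if isinstance(value, int):
        value = [value]
    if not isinstance(value, (list, tuple)) or not all(isinstance(v, int) for v in value):
        raise argparse.ArgumentTypeError(f"expected a list of ints, got {text!r}")
    return list(value)


def build_parser():
    # Hypothetical parser: only the flag names are taken from the README.
    parser = argparse.ArgumentParser(description="Hypothetical sketch of train.py flags")
    parser.add_argument("--gpu_ids", type=parse_int_list, default=[0],
                        help="GPU IDs to use, e.g. [0,1] (assumed default: [0], one GPU)")
    parser.add_argument("--batch_size", type=int, default=None,
                        help="Batch size (None stands in for train.py's actual default)")
    return parser


args = build_parser().parse_args(["--gpu_ids=[0,1]", "--batch_size=128"])
print(args.gpu_ids, args.batch_size)  # [0, 1] 128
```

With the IDs parsed into a list, a PyTorch script would commonly pass them as device_ids to torch.nn.DataParallel; whether train.py does exactly that is not stated here.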

One epoch takes about 4 minutes when using the default arguments and running on an NVIDIA Titan Xp card.

Samples

Epoch 5: [Image: Samples at Epoch 5]

Epoch 10: [Image: Samples at Epoch 10]

Epoch 15: [Image: Samples at Epoch 15]

Epoch 20: [Image: Samples at Epoch 20]

Epoch 25: [Image: Samples at Epoch 25]

Results

Bits per Dimension

Epoch | Train (bits/dim) | Valid (bits/dim)
------|------------------|-----------------
5     | 3.97             | 3.98
10    | 3.76             | 3.76
15    | 3.69             | 3.74
20    | 3.65             | 3.70
25    | 3.62             | 3.74
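Bits per dimension is the negative log-likelihood in base 2, averaged over the D = 3 × 32 × 32 = 3072 dimensions of a CIFAR-10 image; lower is better. For a model with a continuous density on pixels rescaled to [0, 1], the usual conversion adds D · ln 256 nats to account for the 256 discrete intensity levels and then divides by D · ln 2. The sketch below assumes exactly that setup; the function names are hypothetical, and this repository's own loss may differ in detail (for example, in how it folds in any preprocessing transform).

```python
import math


def bits_per_dim(nll_nats, num_dims, num_bins=256):
    """Convert a per-example continuous NLL (in nats) on data scaled to [0, 1]
    into discrete bits per dimension, assuming uniform dequantization into
    num_bins levels (each bin has width 1 / num_bins)."""
    discrete_nll = nll_nats + num_dims * math.log(num_bins)
    return discrete_nll / (num_dims * math.log(2))


def nll_nats_from_bits_per_dim(bpd, num_dims, num_bins=256):
    """Inverse of bits_per_dim: recover the continuous NLL in nats."""
    return bpd * num_dims * math.log(2) - num_dims * math.log(num_bins)


# A CIFAR-10 image has 3 channels x 32 x 32 pixels.
CIFAR10_DIMS = 3 * 32 * 32

# A uniform density on [0, 1]^D has continuous NLL 0, i.e. 8 bits/dim for 8-bit pixels.
print(round(bits_per_dim(0.0, CIFAR10_DIMS), 6))  # 8.0
```

The uniform-density case is a useful sanity check: any model that scores above 8 bits/dim on 8-bit images is doing worse than assigning equal probability to every pixel value.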
