From 0c3c84c55c59b7ad6f114a76e2b647fe7d8bcec2 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sat, 16 Jan 2021 17:48:34 -0800
Subject: [PATCH] allow for straight-through gumbel softmax

---
 README.md                      | 13 +++++++------
 dalle_pytorch/dalle_pytorch.py |  8 +++++---
 setup.py                       |  2 +-
 3 files changed, 13 insertions(+), 10 deletions(-)

diff --git a/README.md b/README.md
index 986e8ace..961a1a96 100644
--- a/README.md
+++ b/README.md
@@ -28,12 +28,13 @@ from dalle_pytorch import DiscreteVAE
 
 vae = DiscreteVAE(
     image_size = 256,
-    num_layers = 3,          # number of downsamples - ex. 256 / (2 ** 3) = (32 x 32 feature map)
-    num_tokens = 8192,       # number of visual tokens. in the paper, they used 8192, but could be smaller for downsized projects
-    codebook_dim = 512,      # codebook dimension
-    hidden_dim = 64,         # hidden dimension
-    num_resnet_blocks = 1,   # number of resnet blocks
-    temperature = 0.9        # gumbel softmax temperature, the lower this is, the more hard the discretization
+    num_layers = 3,           # number of downsamples - ex. 256 / (2 ** 3) = (32 x 32 feature map)
+    num_tokens = 8192,        # number of visual tokens. in the paper, they used 8192, but could be smaller for downsized projects
+    codebook_dim = 512,       # codebook dimension
+    hidden_dim = 64,          # hidden dimension
+    num_resnet_blocks = 1,    # number of resnet blocks
+    temperature = 0.9,        # gumbel softmax temperature, the lower this is, the more hard the discretization
+    straight_through = False  # straight-through for gumbel softmax. unclear if it is better one way or the other
 )
 
 images = torch.randn(4, 3, 256, 256)

diff --git a/dalle_pytorch/dalle_pytorch.py b/dalle_pytorch/dalle_pytorch.py
index 83c5323e..4f9bd6db 100644
--- a/dalle_pytorch/dalle_pytorch.py
+++ b/dalle_pytorch/dalle_pytorch.py
@@ -72,7 +72,8 @@ def __init__(
         num_resnet_blocks = 0,
         hidden_dim = 64,
         channels = 3,
-        temperature = 0.9
+        temperature = 0.9,
+        straight_through = False
     ):
         super().__init__()
         assert log2(image_size).is_integer(), 'image size must be a power of 2'
@@ -83,8 +84,9 @@ def __init__(
         self.num_tokens = num_tokens
         self.num_layers = num_layers
         self.temperature = temperature
+        self.straight_through = straight_through
         self.codebook = nn.Embedding(num_tokens, codebook_dim)
-
+
         hdim = hidden_dim
 
         enc_chans = [hidden_dim] * num_layers
@@ -146,7 +148,7 @@ def forward(
         if return_logits:
             return logits # return logits for getting hard image indices for DALL-E training
 
-        soft_one_hot = F.gumbel_softmax(logits, tau = self.temperature, dim = 1)
+        soft_one_hot = F.gumbel_softmax(logits, tau = self.temperature, dim = 1, hard = self.straight_through)
         sampled = einsum('b n h w, n d -> b d h w', soft_one_hot, self.codebook.weight)
         out = self.decoder(sampled)
 
diff --git a/setup.py b/setup.py
index a221bb84..6090c28f 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'dalle-pytorch',
   packages = find_packages(),
-  version = '0.0.39',
+  version = '0.0.40',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
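Note for readers (not part of the patch): a minimal sketch of what the new `hard = self.straight_through` flag changes in `F.gumbel_softmax`. With `hard = True`, the forward pass emits exact one-hot vectors while gradients still flow through the soft sample (the straight-through estimator). The tensor shapes below are illustrative assumptions mirroring the README defaults, not values taken from the repo.

    import torch
    import torch.nn.functional as F

    # Illustrative shapes: batch of 4, logits over 8192 visual tokens on a 32 x 32 feature map
    logits = torch.randn(4, 8192, 32, 32, requires_grad = True)

    # hard = False: a relaxed, continuous distribution over codebook entries
    soft = F.gumbel_softmax(logits, tau = 0.9, dim = 1, hard = False)

    # hard = True (straight-through): one-hot in the forward pass,
    # but the backward pass uses the soft sample's gradient
    hard = F.gumbel_softmax(logits, tau = 0.9, dim = 1, hard = True)

    print(soft.sum(dim = 1).allclose(torch.ones(4, 32, 32)))  # True: soft probabilities sum to 1
    print(bool(((hard == 0) | (hard == 1)).all()))            # True: exactly one-hot

Either setting keeps the codebook lookup differentiable; whether the straight-through variant trains the discrete VAE better is, as the README comment says, unclear one way or the other.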