Skip to content

Commit

Permalink
add ability to specify number of resnet blocks in discrete vae
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 10, 2021
1 parent d5004d6 commit c5c56e2
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 21 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ from dalle_pytorch import DiscreteVAE
vae = DiscreteVAE(
image_size = 256,
num_layers = 3, # number of downsamples - ex. 256 / (2 ** 3) = (32 x 32 feature map)
num_resnet_blocks = 1, # number of residual blocks at each layer
num_tokens = 1024, # number of visual tokens. iGPT had 512, so probably should have more
codebook_dim = 512, # codebook dimension
hidden_dim = 64, # hidden dimension
Expand All @@ -46,6 +47,7 @@ from dalle_pytorch import DiscreteVAE, DALLE
vae = DiscreteVAE(
image_size = 256,
num_layers = 3,
num_resnet_blocks = 1,
num_tokens = 1024,
codebook_dim = 512,
hidden_dim = 64,
Expand Down
75 changes: 55 additions & 20 deletions dalle_pytorch/dalle_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ def exists(val):
def default(val, d):
return val if exists(val) else d

def always(val):
def inner(*args, **kwargs):
return val
return inner

def is_empty(t):
return t.nelement() == 0

Expand Down Expand Up @@ -43,14 +48,46 @@ def top_k(logits, thres = 0.5):

# discrete vae class

def ConvBlock(chan_in, chan_out):
return nn.Sequential(
nn.Conv2d(chan_in, chan_out, 3, padding = 1),
nn.ReLU(),
nn.Conv2d(chan_out, chan_out, 3, padding = 1),
nn.ReLU()
)

class ResBlock(nn.Module):
def __init__(
self,
chan_in,
chan_out,
num_blocks = 1,
upsample = False
):
super().__init__()
self.upsample = upsample
conv_kls = nn.ConvTranspose2d if upsample else nn.Conv2d
self.res = conv_kls(chan_in, chan_out, 1, stride = 2) if (num_blocks > 0 and chan_in != chan_out) else always(0)

self.net = nn.Sequential(*[
nn.Sequential(conv_kls(chan_in, chan_out, 4, stride = 2, padding = 1), nn.ReLU()),
*[ConvBlock(chan_out, chan_out) for _ in range(num_blocks)]
])

def forward(self, x):
out = self.net(x)
res_kwargs = {'output_size': out.shape[2:]} if self.upsample else {}
return out + self.res(x, **res_kwargs)

class DiscreteVAE(nn.Module):
def __init__(
self,
image_size = 256,
num_tokens = 512,
codebook_dim = 512,
hidden_dim = 64,
num_layers = 3,
num_resnet_blocks = 1,
hidden_dim = 64,
channels = 3,
temperature = 0.9
):
Expand All @@ -66,28 +103,26 @@ def __init__(

hdim = hidden_dim

encoder_layers = []
decoder_layers = []
for i in range(num_layers):
is_first = i == 0
enc_chans = [hidden_dim] * num_layers
dec_chans = reversed(enc_chans)

enc_in = channels if is_first else hdim
encoder_layers += [
nn.Conv2d(enc_in, hdim, 4, stride = 2, padding = 1),
nn.ReLU(),
]
dec_in = codebook_dim if is_first else hdim
decoder_layers += [
nn.ConvTranspose2d(dec_in, hdim, 4, stride = 2, padding = 1),
nn.ReLU(),
]
enc_chans = [channels, *enc_chans]
dec_chans = [codebook_dim, *dec_chans]

enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))

enc_layers = []
dec_layers = []

for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
enc_layers.append(ResBlock(enc_in, enc_out, num_blocks = num_resnet_blocks))
dec_layers.append(ResBlock(dec_in, dec_out, num_blocks = num_resnet_blocks, upsample = True))

encoder_layers.append(nn.Conv2d(hdim, num_tokens, 1))
decoder_layers.append(nn.Conv2d(hdim, channels, 1))
enc_layers.append(nn.Conv2d(enc_chans[-1], num_tokens, 1))
dec_layers.append(nn.Conv2d(dec_chans[-1], channels, 1))

self.encoder = nn.Sequential(*encoder_layers)
self.decoder = nn.Sequential(*decoder_layers)
self.encoder = nn.Sequential(*enc_layers)
self.decoder = nn.Sequential(*dec_layers)

@torch.no_grad()
def get_codebook_indices(self, images):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'dalle-pytorch',
packages = find_packages(),
version = '0.0.30',
version = '0.0.31',
license='MIT',
description = 'DALL-E - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit c5c56e2

Please sign in to comment.