From 587197afd99aefd0f9d1b040182812cd1a378e45 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Fri, 8 Jan 2021 13:11:54 -0800
Subject: [PATCH] release development version where VAE and decoder are
 trained end-to-end

---
 README.md                      | 32 +++++++-------------------------
 dalle_pytorch/dalle_pytorch.py | 21 +++++++--------------
 setup.py                       |  4 ++--
 3 files changed, 16 insertions(+), 41 deletions(-)

diff --git a/README.md b/README.md
index 95a3a9a3..34f9ef20 100644
--- a/README.md
+++ b/README.md
@@ -6,35 +6,15 @@ Implementation / replication of DALL-E
 
 ## Install
 
+Development branch
+
 ```bash
-$ pip install dalle-pytorch
+$ pip install dalle-pytorch-dev
 ```
 
 ## Usage
 
-Train VAE
-
-```python
-import torch
-from dalle_pytorch import DiscreteVAE
-
-vae = DiscreteVAE(
-    image_size = 256,
-    num_layers = 3,
-    num_tokens = 1024,
-    codebook_dim = 512,
-    hidden_dim = 64
-)
-
-images = torch.randn(4, 3, 256, 256)
-
-loss = vae(images, return_recon_loss = True)
-loss.backward()
-
-# train with a lot of data to learn a good codebook
-```
-
-Train DALL-E with pretrained VAE from above
+Train DALL-E with VAE end-to-end
 
 ```python
 import torch
@@ -53,7 +33,8 @@ dalle = DALLE(
     num_text_tokens = 10000,    # vocab size for text
     text_seq_len = 256,         # text sequence length
     depth = 6,                  # should be 64
-    heads = 8
+    heads = 8,
+    vae_loss_coef = 1.          # multiplier for vae reconstruction loss
 )
 
 text = torch.randint(0, 10000, (4, 256))
@@ -113,6 +94,7 @@ images.shape # (2, 3, 256, 256)
 
 Or you can just use the official CLIP model to rank the images from DALL-E
 
+```
 ## Citations
 
 ```bibtex
diff --git a/dalle_pytorch/dalle_pytorch.py b/dalle_pytorch/dalle_pytorch.py
index fed7a321..cdf247ff 100644
--- a/dalle_pytorch/dalle_pytorch.py
+++ b/dalle_pytorch/dalle_pytorch.py
@@ -222,7 +222,6 @@ def __init__(
         text_seq_len = 256,
         depth = 6,
         heads = 8,
-        train_vae = True, # leave uncertainty for someone to explore
         vae_loss_coef = 1.
     ):
         super().__init__()
@@ -249,7 +248,6 @@ def __init__(
         self.vae = vae
         self.vae_loss_coef = vae_loss_coef
-        self.train_vae = train_vae
 
         if exists(self.vae):
             self.vae = vae
             self.image_emb = vae.codebook
@@ -337,14 +335,10 @@ def forward(
         is_raw_image = len(image.shape) == 4
 
         if is_raw_image:
-            if self.train_vae:
-                orig_image = image
-                codebook_emb = self.vae(image, return_soft_embeddings = True)
-                image_token_emb = rearrange(codebook_emb, 'b c h w -> b (h w) c')
-                image = self.vae.get_codebook_indices(image)
-            else:
-                image = self.vae.get_codebook_indices(image)
-                image_token_emb = self.image_emb(image)
+            orig_image = image
+            codebook_emb = self.vae(image, return_soft_embeddings = True)
+            image_token_emb = rearrange(codebook_emb, 'b c h w -> b (h w) c')
+            image = self.vae.get_codebook_indices(image)
         else:
             image_token_emb = self.image_emb(image)
 
@@ -375,9 +369,8 @@ def forward(
         labels = F.pad(labels, (0, 1), value = eos_token_id) # last token predicts EOS
 
         loss = F.cross_entropy(logits.transpose(1, 2), labels[:, 1:])
-        if self.train_vae:
-            recon_img = self.vae.decoder(codebook_emb)
-            vae_loss = F.mse_loss(recon_img, orig_image) * self.vae_loss_coef
-            loss = loss + vae_loss
+        recon_img = self.vae.decoder(codebook_emb)
+        vae_loss = F.mse_loss(recon_img, orig_image) * self.vae_loss_coef
+        loss = loss + vae_loss
 
         return loss

diff --git a/setup.py b/setup.py
index 409e382d..8fd80311 100644
--- a/setup.py
+++ b/setup.py
@@ -1,9 +1,9 @@
 from setuptools import setup, find_packages
 
 setup(
-  name = 'dalle-pytorch',
+  name = 'dalle-pytorch-dev',
  packages = find_packages(),
-  version = '0.0.25',
+  version = '0.0.1',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
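Taken together, the `dalle_pytorch.py` hunks collapse the old two-stage recipe (train the VAE, then train DALL-E on top of it) into a single joint objective: the autoregressive token loss plus a scaled VAE reconstruction loss, minimized by one backward pass. A minimal sketch of that arithmetic follows; `end_to_end_loss` is a hypothetical helper for illustration, not part of the library, and the `vae` interface (`return_soft_embeddings`, `.decoder`) is assumed to behave as in the diff.

```python
import torch.nn.functional as F

# Hypothetical helper sketching the joint objective from the patch.
# `vae` is assumed to return differentiable soft codebook embeddings
# when called with return_soft_embeddings = True, as in the diff.
def end_to_end_loss(vae, logits, labels, image, vae_loss_coef = 1.):
    # autoregressive loss over text + image tokens, as in DALLE.forward
    ce_loss = F.cross_entropy(logits.transpose(1, 2), labels[:, 1:])

    # decode the soft embeddings and penalize reconstruction error,
    # scaled by the new vae_loss_coef hyperparameter
    codebook_emb = vae(image, return_soft_embeddings = True)  # (b, c, h, w)
    recon_img = vae.decoder(codebook_emb)
    vae_loss = F.mse_loss(recon_img, image) * vae_loss_coef

    # one backward() through this sum updates transformer and VAE jointly
    return ce_loss + vae_loss
```

Note that in the patched `forward()`, the transformer's image-token embeddings are also taken from the soft codebook output, so gradients reach the VAE through both terms; `vae_loss_coef` only rebalances the reconstruction side.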