Skip to content

Commit

Permalink
release development version where vae and decoder are trained end-to-end
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 8, 2021
1 parent 0c75b48 commit 587197a
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 41 deletions.
32 changes: 7 additions & 25 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,35 +6,15 @@ Implementation / replication of <a href="https://openai.com/blog/dall-e/">DALL-E

## Install

Development branch

```bash
$ pip install dalle-pytorch
$ pip install dalle-pytorch-dev
```

## Usage

Train VAE

```python
import torch
from dalle_pytorch import DiscreteVAE

vae = DiscreteVAE(
image_size = 256,
num_layers = 3,
num_tokens = 1024,
codebook_dim = 512,
hidden_dim = 64
)

images = torch.randn(4, 3, 256, 256)

loss = vae(images, return_recon_loss = True)
loss.backward()

# train with a lot of data to learn a good codebook
```

Train DALL-E with pretrained VAE from above
Train DALL-E with VAE end-to-end

```python
import torch
Expand All @@ -53,7 +33,8 @@ dalle = DALLE(
num_text_tokens = 10000, # vocab size for text
text_seq_len = 256, # text sequence length
depth = 6, # should be 64
heads = 8
heads = 8,
vae_loss_coef = 1. # multiplier for vae reconstruction loss
)

text = torch.randint(0, 10000, (4, 256))
Expand Down Expand Up @@ -113,6 +94,7 @@ images.shape # (2, 3, 256, 256)

Or you can just use the official <a href="https://github.com/openai/CLIP">CLIP model</a> to rank the images from DALL-E

```
## Citations
```bibtex
Expand Down
21 changes: 7 additions & 14 deletions dalle_pytorch/dalle_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,6 @@ def __init__(
text_seq_len = 256,
depth = 6,
heads = 8,
train_vae = True, # leave uncertainty for someone to explore
vae_loss_coef = 1.
):
super().__init__()
Expand All @@ -249,7 +248,6 @@ def __init__(

self.vae = vae
self.vae_loss_coef = vae_loss_coef
self.train_vae = train_vae
if exists(self.vae):
self.vae = vae
self.image_emb = vae.codebook
Expand Down Expand Up @@ -337,14 +335,10 @@ def forward(
is_raw_image = len(image.shape) == 4

if is_raw_image:
if self.train_vae:
orig_image = image
codebook_emb = self.vae(image, return_soft_embeddings = True)
image_token_emb = rearrange(codebook_emb, 'b c h w -> b (h w) c')
image = self.vae.get_codebook_indices(image)
else:
image = self.vae.get_codebook_indices(image)
image_token_emb = self.image_emb(image)
orig_image = image
codebook_emb = self.vae(image, return_soft_embeddings = True)
image_token_emb = rearrange(codebook_emb, 'b c h w -> b (h w) c')
image = self.vae.get_codebook_indices(image)
else:
image_token_emb = self.image_emb(image)

Expand Down Expand Up @@ -375,9 +369,8 @@ def forward(
labels = F.pad(labels, (0, 1), value = eos_token_id) # last token predicts EOS
loss = F.cross_entropy(logits.transpose(1, 2), labels[:, 1:])

if self.train_vae:
recon_img = self.vae.decoder(codebook_emb)
vae_loss = F.mse_loss(recon_img, orig_image) * self.vae_loss_coef
loss = loss + vae_loss
recon_img = self.vae.decoder(codebook_emb)
vae_loss = F.mse_loss(recon_img, orig_image) * self.vae_loss_coef
loss = loss + vae_loss

return loss
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from setuptools import setup, find_packages

setup(
name = 'dalle-pytorch',
name = 'dalle-pytorch-dev',
packages = find_packages(),
version = '0.0.25',
version = '0.0.1',
license='MIT',
description = 'DALL-E - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 587197a

Please sign in to comment.