fix generation and further cleanup by moving generate method into Dalle class
lucidrains committed Jan 8, 2021
1 parent 57f6cf2 commit b1d9b44
Showing 4 changed files with 46 additions and 51 deletions.
8 changes: 2 additions & 6 deletions README.md
````diff
@@ -117,10 +117,8 @@ loss.backward()
 Finally, to generate images
 
 ```python
-from dalle_pytorch import generate_images
-
-images = generate_images(
-    dalle,
+images = dalle.generate_images(
     vae = vae,
     text = text,
     mask = mask
@@ -132,10 +130,8 @@ images.shape # (2, 3, 256, 256)
 To get the similarity scores from your trained Clipper, just do
 
 ```python
-from dalle_pytorch import generate_images
-
-images, scores = generate_images(
-    dalle,
+images, scores = dalle.generate_images(
     vae = vae,
     text = text,
     mask = mask,
````
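To see the new call sites in context, here is a minimal usage sketch. It assumes trained `dalle` (DALLE), `vae` (DiscreteVAE), and `clip` (CLIP) objects already exist; the vocabulary size and text length below are placeholder values chosen for illustration, not values taken from this repository:

```python
import torch

# placeholder token ids and mask -- batch of 2, as in the README example;
# vocabulary size 10000 and text length 256 are illustrative assumptions
text = torch.randint(0, 10000, (2, 256))
mask = torch.ones_like(text).bool()

# generation is now a method on the DALLE instance; the VAE is passed in
images = dalle.generate_images(vae = vae, text = text, mask = mask)
images.shape # (2, 3, 256, 256)

# with a trained CLIP instance (hypothetical `clip`), scores come back too
images, scores = dalle.generate_images(vae = vae, text = text, mask = mask, clipper = clip)
```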
1 change: 0 additions & 1 deletion dalle_pytorch/__init__.py
```diff
@@ -1,2 +1 @@
 from dalle_pytorch.dalle_pytorch import DALLE, CLIP, DiscreteVAE
-from dalle_pytorch.dalle_pytorch import generate_images
```
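Since the module-level export is removed, any downstream code importing `generate_images` from the package root now raises an `ImportError`. A minimal migration sketch, assuming the same `dalle`, `vae`, `text`, and `mask` objects as in the README example above:

```python
from dalle_pytorch import DALLE, CLIP, DiscreteVAE  # still exported

# before this commit:
# from dalle_pytorch import generate_images
# images = generate_images(dalle, vae = vae, text = text, mask = mask)

# after this commit -- generation lives on the DALLE instance:
images = dalle.generate_images(vae = vae, text = text, mask = mask)
```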
86 changes: 43 additions & 43 deletions dalle_pytorch/dalle_pytorch.py
```diff
@@ -37,49 +37,6 @@ def top_k(logits, thres = 0.5):
     probs.scatter_(1, ind, val)
     return probs
 
-@torch.no_grad()
-@eval_decorator
-def generate_images(
-    model,
-    vae,
-    text,
-    clipper = None,
-    mask = None,
-    filter_thres = 0.5,
-    temperature = 1.
-):
-    x = text
-
-    text_seq_len = model.text_seq_len
-    image_seq_len = model.image_seq_len
-    total_len = text_seq_len + model.image_seq_len - text.shape[1]
-
-    out = x
-    for _ in range(total_len):
-        text, image = x[:, :text_seq_len], x[:, text_seq_len:]
-        logits = model(text, image, mask = mask)[:, -1, :]
-        filtered_logits = top_k(logits, thres = filter_thres)
-        probs = F.softmax(filtered_logits / temperature, dim=-1)
-
-        sample = torch.multinomial(probs, 1)
-        out = torch.cat((out, sample), dim=-1)
-
-        if out.shape[1] <= text_seq_len:
-            mask = F.pad(mask, (0, 1), value=True)
-
-    text_seq = torch.cat((x[:, :1], out[:, :(text_seq_len - 1)]), dim = 1)
-
-    img_seq = out[:, -image_seq_len:]
-    img_seq -= model.num_text_tokens
-
-    images = vae.decode(img_seq)
-
-    if exists(clipper):
-        scores = clipper(text_seq, images, return_loss = False)
-        return images, scores
-
-    return images
-
 # discrete vae class
 
 class DiscreteVAE(nn.Module):
@@ -304,6 +261,49 @@ def __init__(
 
         self.register_buffer('logits_mask', logits_mask)
 
+    @torch.no_grad()
+    @eval_decorator
+    def generate_images(
+        self,
+        vae,
+        text,
+        clipper = None,
+        mask = None,
+        filter_thres = 0.5,
+        temperature = 1.
+    ):
+        text_seq_len, image_seq_len, num_text_tokens = self.text_seq_len, self.image_seq_len, self.num_text_tokens
+        total_len = text_seq_len + image_seq_len
+
+        out = text
+        for cur_len in range(text.shape[1], total_len):
+            is_image = cur_len >= text_seq_len
+
+            text, image = out[:, :text_seq_len], out[:, text_seq_len:]
+
+            logits = self(text, image, mask = mask)[:, -1, :]
+
+            filtered_logits = top_k(logits, thres = filter_thres)
+            probs = F.softmax(filtered_logits / temperature, dim = -1)
+            sample = torch.multinomial(probs, 1)
+
+            sample -= (num_text_tokens if is_image else 0) # offset sampled token if it is an image token, since logit space is composed of text and then image tokens
+            out = torch.cat((out, sample), dim=-1)
+
+            if out.shape[1] <= text_seq_len:
+                mask = F.pad(mask, (0, 1), value = True)
+
+        text_seq = out[:, :text_seq_len]
+
+        img_seq = out[:, -image_seq_len:]
+        images = vae.decode(img_seq)
+
+        if exists(clipper):
+            scores = clipper(text_seq, images, return_loss = False)
+            return images, scores
+
+        return images
+
     def forward(
         self,
         text,
```
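Two things change here besides the move into the class. The old function never fed sampled tokens back in: it sliced `x`, the original prompt, inside the loop, so every step conditioned on the same input. The method version conditions on the growing `out` sequence, and since the logit space is text tokens followed by image tokens, it subtracts `num_text_tokens` from each image sample before appending, keeping the fed-back sequence (and the one later handed to `vae.decode`) in codebook range. The sketch below isolates that fixed sampling step with toy sizes; the body of `top_k` is a reconstruction consistent with the two context lines shown above, and the hand-built mask standing in for the class's `logits_mask` buffer is likewise an assumption:

```python
import torch
import torch.nn.functional as F

def top_k(logits, thres = 0.5):
    # keep the top (1 - thres) fraction of logits, push the rest to -inf
    # (reconstructed body -- only the last two lines appear in the diff)
    k = max(1, int((1 - thres) * logits.shape[-1]))
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

# toy logit space: 4 text tokens followed by 6 image codebook entries
num_text_tokens = 4
temperature = 1.

logits = torch.randn(1, 10)                 # model output at the last position
logits[:, :num_text_tokens] = float('-inf') # stand-in for the class's logits_mask,
                                            # assumed to keep text ids out of image slots

filtered_logits = top_k(logits, thres = 0.5)
probs = F.softmax(filtered_logits / temperature, dim = -1)
sample = torch.multinomial(probs, 1)

# the fix: at an image position the sampled id indexes the image half of the
# vocabulary, so shift it back into codebook range before appending
is_image = True
sample -= num_text_tokens if is_image else 0
print(sample.item())  # a VAE codebook index in [0, 6)
```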
2 changes: 1 addition & 1 deletion setup.py
```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'dalle-pytorch',
   packages = find_packages(),
-  version = '0.0.19',
+  version = '0.0.21',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
```
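A quick way to confirm the bump after upgrading; `pkg_resources` ships with setuptools, so this check works in any environment that installed the package:

```python
import pkg_resources

# expected to print 0.0.21 once this commit is released and installed
print(pkg_resources.get_distribution('dalle-pytorch').version)
```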
