Skip to content

Commit

Permalink
make sure generation is done in eval mode, and sets it back to origin…
Browse files Browse the repository at this point in the history
…al state
  • Loading branch information
lucidrains committed Jan 8, 2021
1 parent e0456ae commit 57f6cf2
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
13 changes: 12 additions & 1 deletion dalle_pytorch/dalle_pytorch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from math import log2, sqrt
from math import sqrt
import torch
from torch import nn, einsum
import torch.nn.functional as F
Expand All @@ -18,6 +18,15 @@ def masked_mean(t, mask, dim = 1):
t = t.masked_fill(~mask[:, :, None], 0.)
return t.sum(dim = 1) / mask.sum(dim = 1)[..., None]

def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
model.eval()
out = fn(model, *args, **kwargs)
model.train(was_training)
return out
return inner

# sampling helpers

def top_k(logits, thres = 0.5):
Expand All @@ -29,6 +38,7 @@ def top_k(logits, thres = 0.5):
return probs

@torch.no_grad()
@eval_decorator
def generate_images(
model,
vae,
Expand Down Expand Up @@ -110,6 +120,7 @@ def __init__(
self.num_tokens = num_tokens
self.codebook = nn.Embedding(num_tokens, dim)

@torch.no_grad()
def get_codebook_indices(self, images):
logits = self.forward(images, return_logits = True)
codebook_indices = logits.argmax(dim = 1).flatten(1)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'dalle-pytorch',
packages = find_packages(),
version = '0.0.18',
version = '0.0.19',
license='MIT',
description = 'DALL-E - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 57f6cf2

Please sign in to comment.