diff --git a/dalle_pytorch/dalle_pytorch.py b/dalle_pytorch/dalle_pytorch.py
index 46ce2657..d638f93c 100644
--- a/dalle_pytorch/dalle_pytorch.py
+++ b/dalle_pytorch/dalle_pytorch.py
@@ -1,4 +1,4 @@
-from math import log2, sqrt
+from math import sqrt
 import torch
 from torch import nn, einsum
 import torch.nn.functional as F
@@ -18,6 +18,15 @@ def masked_mean(t, mask, dim = 1):
     t = t.masked_fill(~mask[:, :, None], 0.)
     return t.sum(dim = 1) / mask.sum(dim = 1)[..., None]
 
+def eval_decorator(fn):
+    def inner(model, *args, **kwargs):
+        was_training = model.training
+        model.eval()
+        out = fn(model, *args, **kwargs)
+        model.train(was_training)
+        return out
+    return inner
+
 # sampling helpers
 
 def top_k(logits, thres = 0.5):
@@ -29,6 +38,7 @@ def top_k(logits, thres = 0.5):
     return probs
 
 @torch.no_grad()
+@eval_decorator
 def generate_images(
     model,
     vae,
@@ -110,6 +120,7 @@ def __init__(
         self.num_tokens = num_tokens
         self.codebook = nn.Embedding(num_tokens, dim)
 
+    @torch.no_grad()
     def get_codebook_indices(self, images):
         logits = self.forward(images, return_logits = True)
         codebook_indices = logits.argmax(dim = 1).flatten(1)
diff --git a/setup.py b/setup.py
index 93d7b2eb..e23f55a3 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'dalle-pytorch',
   packages = find_packages(),
-  version = '0.0.18',
+  version = '0.0.19',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',