From c32f7249d50e2eb073feff18963ef92a4eea3ac3 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 7 Jan 2021 14:22:27 -0800 Subject: [PATCH] make sure last token can still predict eos --- dalle_pytorch/dalle_pytorch.py | 5 +++-- setup.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/dalle_pytorch/dalle_pytorch.py b/dalle_pytorch/dalle_pytorch.py index 6f225ec4..d5d08f65 100644 --- a/dalle_pytorch/dalle_pytorch.py +++ b/dalle_pytorch/dalle_pytorch.py @@ -256,6 +256,7 @@ def __init__( self.text_seq_len = text_seq_len self.image_seq_len = image_seq_len + seq_len = text_seq_len + image_seq_len total_tokens = num_text_tokens + num_image_tokens + 1 # extra for EOS self.total_tokens = total_tokens @@ -271,7 +272,7 @@ def __init__( nn.Linear(dim, self.total_tokens), ) - seq_range = torch.arange(text_seq_len + image_seq_len) + seq_range = torch.arange(seq_len) logits_range = torch.arange(total_tokens) seq_range = rearrange(seq_range, 'n -> () n ()') @@ -280,7 +281,7 @@ def __init__( logits_mask = ( ((seq_range >= (text_seq_len - 1)) & (logits_range < num_text_tokens)) | ((seq_range < (text_seq_len - 1)) & (logits_range >= num_text_tokens)) | - (logits_range >= (total_tokens - 1)) + ((seq_range != (seq_len - 1)) & (logits_range >= (total_tokens - 1))) ) self.register_buffer('logits_mask', logits_mask) diff --git a/setup.py b/setup.py index e9b98bdb..eaace065 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'dalle-pytorch', packages = find_packages(), - version = '0.0.15', + version = '0.0.16', license='MIT', description = 'DALL-E - Pytorch', author = 'Phil Wang',