add autoregressive gMLP layer
lucidrains committed May 19, 2021
1 parent c4330e3 commit bdb0428
Showing 3 changed files with 13 additions and 3 deletions.
11 changes: 10 additions & 1 deletion dalle_pytorch/transformer.py
@@ -9,6 +9,8 @@
  from dalle_pytorch.reversible import ReversibleSequence, SequentialSequence
  from dalle_pytorch.attention import Attention, SparseAttention, SparseConvCausalAttention, SparseAxialCausalAttention

+ from g_mlp_pytorch import gMLPBlock
+
  # helpers

  def exists(val):
@@ -105,11 +107,18 @@ def __init__(
          attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size)
      elif attn_type == 'conv_like':
          attn_class = partial(SparseConvCausalAttention, seq_len = seq_len, image_size = image_fmap_size)
+     elif attn_type == 'mlp':
+         attn_class = partial(gMLPBlock, seq_len = seq_len)
      else:
          raise ValueError(f'attention type "{attn_type}" is not valid')

+     if attn_type != 'mlp':
+         attn = attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
+     else:
+         attn = attn_class(dim = dim, causal = causal, dim_ff = dim * 4)
+
      layers.append(nn.ModuleList([
-         LayerScale(dim, ind + 1, PreNorm(dim, attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout))),
+         LayerScale(dim, ind + 1, PreNorm(dim, attn)),
          LayerScale(dim, ind + 1, PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout)))
      ]))

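With this change, 'mlp' becomes a valid entry in the attn_types argument: for those layers the block wraps a causal gMLPBlock (with dim_ff = dim * 4, matching the usual 4x feed-forward expansion) in the same LayerScale/PreNorm stack instead of an attention layer. A minimal usage sketch, assuming the DALLE constructor keyword arguments shown in the repository README around this point in its history (exact arguments may differ across versions):

from dalle_pytorch import DiscreteVAE, DALLE

vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,
    num_tokens = 8192,
    codebook_dim = 512,
    hidden_dim = 64
)

dalle = DALLE(
    dim = 1024,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 12,
    heads = 16,
    dim_head = 64,
    attn_types = ('full', 'mlp')  # alternate full attention layers with causal gMLP layers across the depth
)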
3 changes: 2 additions & 1 deletion setup.py
@@ -4,7 +4,7 @@
  name = 'dalle-pytorch',
  packages = find_packages(),
  include_package_data = True,
- version = '0.11.3',
+ version = '0.12.0',
  license='MIT',
  description = 'DALL-E - Pytorch',
  author = 'Phil Wang',
@@ -21,6 +21,7 @@
  'DALL-E',
  'einops>=0.3',
  'ftfy',
+ 'g-mlp-pytorch',
  'pillow',
  'regex',
  'taming-transformers',
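The new 'g-mlp-pytorch' requirement provides the gMLPBlock imported in transformer.py. As a rough, hypothetical illustration of what "autoregressive" means for a gMLP layer (this is not the g-mlp-pytorch implementation), the spatial gating unit mixes tokens along the sequence dimension with a learned projection, and masking that projection to be lower triangular keeps each position from seeing future tokens:

import torch
from torch import nn

class CausalSpatialGatingUnit(nn.Module):
    # Illustrative sketch only: a token-mixing projection whose weight matrix is
    # masked to be lower triangular, so position i only mixes in positions <= i.
    def __init__(self, dim, seq_len):
        super().__init__()
        self.norm = nn.LayerNorm(dim // 2)
        self.proj = nn.Conv1d(seq_len, seq_len, 1)  # learned mixing weights over the sequence
        self.register_buffer('causal_mask', torch.tril(torch.ones(seq_len, seq_len)))

    def forward(self, x):
        res, gate = x.chunk(2, dim = -1)                           # split channels into residual and gate halves
        gate = self.norm(gate)
        weight = self.proj.weight.squeeze(-1) * self.causal_mask   # (seq_len, seq_len), lower triangular
        gate = torch.einsum('i j, b j d -> b i d', weight, gate) + self.proj.bias[None, :, None]
        return res * gate                                          # gated output with half the input channels

In the full gMLP block this unit sits between two channel projections (dim up to dim_ff and back down to dim), which is what the dim_ff = dim * 4 argument above controls.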
2 changes: 1 addition & 1 deletion train_dalle.py
@@ -41,7 +41,7 @@
  parser.add_argument('--hug', dest='hug', action='store_true')

  parser.add_argument('--bpe_path', type=str,
-                     help='path to your huggingface BPE json file')
+                     help='path to your BPE json file')

  parser.add_argument('--fp16', action='store_true',
                      help='(experimental) - Enable DeepSpeed 16 bit precision. Reduces VRAM.')
