make sure to freeze VAE parameters after being passed into DALL-E
lucidrains committed May 4, 2021
1 parent a0e8ea4 commit 19f4212
Showing 3 changed files with 10 additions and 3 deletions.
5 changes: 5 additions & 0 deletions dalle_pytorch/dalle_pytorch.py
@@ -31,6 +31,10 @@ def masked_mean(t, mask, dim = 1):
     t = t.masked_fill(~mask[:, :, None], 0.)
     return t.sum(dim = 1) / mask.sum(dim = 1)[..., None]
 
+def set_requires_grad(model, value):
+    for param in model.parameters():
+        param.requires_grad = value
+
 def eval_decorator(fn):
     def inner(model, *args, **kwargs):
         was_training = model.training
@@ -347,6 +351,7 @@ def __init__(
         self.total_seq_len = seq_len
 
         self.vae = vae
+        set_requires_grad(self.vae, False) # freeze VAE from being trained
 
         self.transformer = Transformer(
             dim = dim,
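For context, a minimal sketch of what this freeze buys (assuming a recent PyTorch; vae and head below are toy stand-ins for the pretrained VAE and the transformer, not names from the repository):

import torch
from torch import nn

def set_requires_grad(model, value):
    for param in model.parameters():
        param.requires_grad = value

# toy stand-ins: "vae" for the pretrained VAE, "head" for the transformer
vae = nn.Linear(8, 8)
head = nn.Linear(8, 8)
set_requires_grad(vae, False)  # freeze, as DALLE.__init__ now does

loss = head(vae(torch.randn(2, 8))).sum()
loss.backward()

# the frozen module accumulates no gradients; the rest still trains
assert all(p.grad is None for p in vae.parameters())
assert all(p.grad is not None for p in head.parameters())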
2 changes: 1 addition & 1 deletion setup.py
@@ -4,7 +4,7 @@
   name = 'dalle-pytorch',
   packages = find_packages(),
   include_package_data = True,
-  version = '0.11.2',
+  version = '0.11.3',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
6 changes: 4 additions & 2 deletions train_dalle.py
@@ -60,6 +60,8 @@
 def exists(val):
     return val is not None
 
+def get_trainable_params(model):
+    return [params for params in model.parameters() if params.requires_grad]
 
 # constants
 
@@ -229,7 +231,7 @@ def group_weight(model):
 
 # optimizer
 
-opt = Adam(dalle.parameters(), lr=LEARNING_RATE)
+opt = Adam(get_trainable_params(dalle), lr=LEARNING_RATE)
 
 if LR_DECAY:
     scheduler = ReduceLROnPlateau(
Expand Down Expand Up @@ -272,7 +274,7 @@ def group_weight(model):
args=args,
model=dalle,
optimizer=opt,
model_parameters=dalle.parameters(),
model_parameters=get_trainable_params(dalle),
training_data=ds if using_deepspeed else dl,
lr_scheduler=scheduler if LR_DECAY else None,
config_params=deepspeed_config,
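The optimizer change is the other half of the freeze: filtering by requires_grad keeps the frozen VAE weights out of the Adam state and out of what DeepSpeed is handed as model_parameters. A minimal sketch, with encoder and head as hypothetical stand-ins for the VAE and the transformer:

from torch import nn
from torch.optim import Adam

def get_trainable_params(model):
    return [params for params in model.parameters() if params.requires_grad]

# hypothetical two-part model: a frozen encoder plus a trainable head
encoder = nn.Linear(8, 8).requires_grad_(False)
head = nn.Linear(8, 2)
model = nn.Sequential(encoder, head)

trainable = get_trainable_params(model)
assert len(trainable) == len(list(head.parameters()))

# only tensors that require grad reach the optimizer, so the frozen
# encoder carries no optimizer state
opt = Adam(trainable, lr = 3e-4)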
