diff --git a/dalle_pytorch/dalle_pytorch.py b/dalle_pytorch/dalle_pytorch.py index cf5e4555..1998a9fe 100644 --- a/dalle_pytorch/dalle_pytorch.py +++ b/dalle_pytorch/dalle_pytorch.py @@ -112,7 +112,7 @@ def __init__( temperature = 0.9, straight_through = False, kl_div_loss_weight = 0., - normalization = ((0.5,) * 3, (0.5,) * 3) + normalization = ((*((0.5,) * 3), 0), (*((0.5,) * 3), 1)) ): super().__init__() assert log2(image_size).is_integer(), 'image size must be a power of 2' @@ -163,7 +163,7 @@ def __init__( self.kl_div_loss_weight = kl_div_loss_weight # take care of normalization within class - self.normalization = normalization + self.normalization = tuple(map(lambda t: t[:channels], normalization)) self._register_external_parameters() @@ -594,7 +594,8 @@ def forward( if is_raw_image: image_size = self.vae.image_size - assert tuple(image.shape[1:]) == (3, image_size, image_size), f'invalid image of dimensions {image.shape} passed in during training' + channels = self.vae.channels + assert tuple(image.shape[1:]) == (channels, image_size, image_size), f'invalid image of dimensions {image.shape} passed in during training' image = self.vae.get_codebook_indices(image) diff --git a/dalle_pytorch/loader.py b/dalle_pytorch/loader.py index fdca7ab6..68ade2ec 100644 --- a/dalle_pytorch/loader.py +++ b/dalle_pytorch/loader.py @@ -14,6 +14,7 @@ def __init__(self, image_size=128, truncate_captions=False, resize_ratio=0.75, + transparent=False, tokenizer=None, shuffle=False ): @@ -43,9 +44,12 @@ def __init__(self, self.truncate_captions = truncate_captions self.resize_ratio = resize_ratio self.tokenizer = tokenizer + + image_mode = 'RGBA' if transparent else 'RGB' + self.image_transform = T.Compose([ - T.Lambda(lambda img: img.convert('RGB') - if img.mode != 'RGB' else img), + T.Lambda(lambda img: img.convert(image_mode) + if img.mode != image_mode else img), T.RandomResizedCrop(image_size, scale=(self.resize_ratio, 1.), ratio=(1., 1.)), diff --git a/setup.py b/setup.py index 46bedaa7..8ecf2b46 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ name = 'dalle-pytorch', packages = find_packages(), include_package_data = True, - version = '1.6.0', + version = '1.6.1', license='MIT', description = 'DALL-E - Pytorch', author = 'Phil Wang', diff --git a/train_dalle.py b/train_dalle.py index 1a7812c4..89cbd674 100644 --- a/train_dalle.py +++ b/train_dalle.py @@ -390,6 +390,7 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available args.image_text_folder, text_len=TEXT_SEQ_LEN, image_size=IMAGE_SIZE, + transparent=TRANSPARENT, resize_ratio=args.resize_ratio, truncate_captions=args.truncate_captions, tokenizer=tokenizer,