Skip to content

Commit

Permalink
Merge pull request #280 from robvanvolt/loader-for-webdataset-included
Browse files Browse the repository at this point in the history
Added support for webdataset
  • Loading branch information
lucidrains authored Jun 16, 2021
2 parents d6107cc + 122bc51 commit 2eceb84
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 24 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ outputs/
*.pt
taming/
wandb/
dalle-ds-cp/

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down Expand Up @@ -90,6 +91,9 @@ ipython_config.py
# pyenv
.python-version

# Visual Studio Code
.vscode

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
Expand Down
36 changes: 35 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,9 @@ Weights and Biases will allow you to monitor the temperature annealing, image re

Once you have trained a decent VAE to your satisfaction, you can move on to the next step with your model weights at `./vae.pt`.

### DALL-E
### DALL-E Training

## Training using an Image-Text-Folder

Now you just have to invoke the `./train_dalle.py` script, indicating which VAE model you would like to use, as well as the path to your folder of images and text.

Expand Down Expand Up @@ -370,6 +372,38 @@ You likely will not finish DALL-E training as quickly as you did your Discrete V
$ python train_dalle.py --dalle_path ./dalle.pt --image_text_folder /path/to/data
```

## Training using WebDataset

WebDataset files are regular .tar(.gz) files which can be streamed and used for DALLE-pytorch training.
You just need to provide the image (first comma-separated argument) and caption (second comma-separated argument)
column keys after the `--wds` argument. The `--image_text_folder` argument then points to your .tar(.gz) file instead of the data folder.

```python
$ python train_dalle.py --wds img,cap --image_text_folder /path/to/data.tar(.gz)
```

Distributed training with deepspeed works the same way, e.g.:

```python
$ deepspeed train_dalle.py --wds img,cap --image_text_folder /path/to/data.tar(.gz) --fp16 --deepspeed
```

If you have a folder containing shards (a dataset split into several .tar(.gz) files), this is also supported:

```python
$ deepspeed train_dalle.py --wds img,cap --image_text_folder /path/to/shardfolder --fp16 --deepspeed
```

You can stream the data from an HTTP server or Google Cloud Storage like this:

```python
$ deepspeed train_dalle.py --image_text_folder "http://storage.googleapis.com/nvdata-openimages/openimages-train-{000000..000554}.tar" --wds jpg,json --taming --truncate_captions --random_resize_crop_lower_ratio=0.8 --attn_types=full --epochs=2 --fp16 --deepspeed
```

In order to convert your image-text-folder to WebDataset format, you can use one of several methods:
four examples are given in https://www.youtube.com/watch?v=v_PacO-3OGQ, and there is a small helper script that also supports splitting your dataset
into shards of .tar.gz files: https://github.com/robvanvolt/DALLE-datasets/blob/main/wds_create_shards.py

### DALL-E with OpenAI's VAE

You can now also train DALL-E without having to train the Discrete VAE at all, courtesy of OpenAI open-sourcing their model. You simply have to invoke the `train_dalle.py` script without specifying the `--vae_path`
Expand Down
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
name = 'dalle-pytorch',
packages = find_packages(),
include_package_data = True,
version = '0.12.5',
version = '0.13.0',
license='MIT',
description = 'DALL-E - Pytorch',
author = 'Phil Wang',
Expand All @@ -30,7 +30,8 @@
'torchvision',
'transformers',
'tqdm',
'youtokentome'
'youtokentome',
'WebDataset'
],
classifiers=[
'Development Status :: 4 - Beta',
Expand Down
125 changes: 104 additions & 21 deletions train_dalle.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@
from dalle_pytorch.loader import TextImageDataset
from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, ChineseTokenizer, YttmTokenizer

# libraries needed for webdataset support
import webdataset as wds
from torchvision import transforms as T
from PIL import Image
from io import BytesIO


# argument parsing

parser = argparse.ArgumentParser()
Expand All @@ -38,6 +45,13 @@
parser.add_argument('--image_text_folder', type=str, required=True,
help='path to your folder of images and text for learning the DALL-E')

# Opt-in switch for WebDataset input: names the image and caption columns.
parser.add_argument('--wds', type=str, default='',
                    help='Comma separated list of WebDataset (1) image and (2) text column names. Must contain 2 values, e.g. img,cap.')

parser.add_argument('--truncate_captions', dest='truncate_captions', action='store_true',
help='Captions passed in which exceed the max token length will be truncated if this is set.')

Expand Down Expand Up @@ -92,13 +106,13 @@

model_group.add_argument('--dim', default = 512, type = int, help = 'Model dimension')

model_group.add_argument('--text_seq_len', default = 256, type = int, help = 'Text sequence length')
model_group.add_argument('--text_seq_len', default = 128, type = int, help = 'Text sequence length')

model_group.add_argument('--depth', default = 2, type = int, help = 'Model depth')

model_group.add_argument('--heads', default = 8, type = int, help = 'Model number of heads')
model_group.add_argument('--heads', default = 4, type = int, help = 'Model number of heads')

model_group.add_argument('--dim_head', default = 64, type = int, help = 'Model head dimension')
model_group.add_argument('--dim_head', default = 16, type = int, help = 'Model head dimension')

train_group.add_argument('--ff_dropout', default = 0.0, type = float, help = 'Feed forward dropout.')

Expand All @@ -112,10 +126,6 @@

args = parser.parse_args()

# quit early if you used the wrong folder name

assert Path(args.image_text_folder).exists(), f'The path {args.image_text_folder} was not found.'

# helpers

def exists(val):
Expand All @@ -137,6 +147,8 @@ def cp_path_to_dir(cp_path, tag):
return cp_dir

# constants
# WebDataset mode is switched on only when --wds supplies exactly two
# comma-separated column names: (image key, caption key).
WEBDATASET_IMAGE_TEXT_COLUMNS = tuple(args.wds.split(','))
ENABLE_WEBDATASET = len(WEBDATASET_IMAGE_TEXT_COLUMNS) == 2

DALLE_OUTPUT_FILE_NAME = args.dalle_output_file_name + ".pt"

Expand Down Expand Up @@ -169,6 +181,27 @@ def cp_path_to_dir(cp_path, tag):

DEEPSPEED_CP_AUX_FILENAME = 'auxiliary.pt'

# Validate --image_text_folder up front so that a wrong path fails before any
# distributed/model setup happens. In WebDataset mode this also resolves
# DATASET, which is either a list of local shard paths or a single string
# (a file path or a "pipe:" command) that is later passed to wds.WebDataset.
if not ENABLE_WEBDATASET:
    # quit early if you used the wrong folder name
    assert Path(args.image_text_folder).exists(), f'The path {args.image_text_folder} was not found.'
else:
    if Path(args.image_text_folder).is_dir():
        # Local folder of shards: collect every path below it containing ".tar"
        # (this covers .tar and .tar.gz) and quit early if none were found.
        DATASET = [str(p) for p in Path(args.image_text_folder).glob("**/*") if ".tar" in str(p).lower()]
        assert len(DATASET) > 0, 'The directory ({}) does not contain any WebDataset/.tar files.'.format(args.image_text_folder)
        print('Found {} WebDataset .tar(.gz) file(s) under given path {}!'.format(len(DATASET), args.image_text_folder))
    elif ('http://' in args.image_text_folder.lower()) or ('https://' in args.image_text_folder.lower()):
        # Remote shards over http(s) are streamed through curl.
        # NOTE(review): `|| true` forces a zero exit status, so a failed
        # download may only show up as an empty/short stream - confirm intended.
        DATASET = f"pipe:curl -L -s {args.image_text_folder} || true"
        # DATASET is a command string here, so there is no file count to
        # report; print the link itself instead of len(DATASET).
        print('Found http(s) link under given path {}!'.format(args.image_text_folder))
    elif 'gs://' in args.image_text_folder.lower():
        # Google Cloud Storage shards are streamed through gsutil.
        DATASET = f"pipe:gsutil cat {args.image_text_folder} || true"
        print('Found GCS link under given path {}!'.format(args.image_text_folder))
    elif '.tar' in args.image_text_folder:
        # A single local .tar(.gz) file (or a shard pattern) is used as given.
        DATASET = args.image_text_folder
        print('Found WebDataset .tar(.gz) file under given path {}!'.format(args.image_text_folder))
    else:
        raise Exception('No folder, no .tar(.gz) and no url pointing to tar files provided under {}.'.format(args.image_text_folder))

# initialize distributed backend

distr_backend = distributed_utils.set_backend_from_args(args)
Expand Down Expand Up @@ -283,19 +316,61 @@ def group_weight(model):

is_shuffle = not distributed_utils.using_backend(distributed_utils.HorovodBackend)

ds = TextImageDataset(
args.image_text_folder,
text_len=TEXT_SEQ_LEN,
image_size=IMAGE_SIZE,
resize_ratio=args.resize_ratio,
truncate_captions=args.truncate_captions,
tokenizer=tokenizer,
shuffle=is_shuffle,
)
# Preprocessing applied to each decoded PIL image in WebDataset mode:
# force RGB, take a random square (aspect ratio 1:1) crop covering at least
# `args.resize_ratio` of the image area resized to IMAGE_SIZE, then convert
# the result to a tensor.
_preproc_steps = [
    T.Lambda(lambda pil_img: pil_img if pil_img.mode == 'RGB' else pil_img.convert('RGB')),
    T.RandomResizedCrop(IMAGE_SIZE, scale=(args.resize_ratio, 1.), ratio=(1., 1.)),
    T.ToTensor(),
]
imagepreproc = T.Compose(_preproc_steps)

def imagetransform(b):
    """Open the raw image bytes ``b`` of a sample as a PIL image."""
    byte_stream = BytesIO(b)
    return Image.open(byte_stream)

def tokenize(s):
    """Tokenize the UTF-8 encoded caption bytes ``s``.

    The caption is decoded, tokenized to ``TEXT_SEQ_LEN`` (truncating
    over-long captions only when ``--truncate_captions`` is set), and the
    leading dimension of the tokenizer output is squeezed away.
    """
    caption = s.decode('utf-8')
    tokens = tokenizer.tokenize(
        caption,
        TEXT_SEQ_LEN,
        truncate_text=args.truncate_captions,
    )
    return tokens.squeeze(0)

# Build the training dataset `ds`: a streaming WebDataset pipeline when --wds
# was given, otherwise the folder-based TextImageDataset.
if ENABLE_WEBDATASET:
DATASET_SIZE = int(1e9) # You need to set a nominal length for the Dataset in order to avoid warnings from DataLoader

# Column keys supplied via --wds: image key first, caption key second.
myimg, mycap = WEBDATASET_IMAGE_TEXT_COLUMNS
# First mapping pass: open the raw image bytes as a PIL image and turn the
# raw caption bytes into tokens.
image_text_mapping = {
myimg: imagetransform,
mycap: tokenize
}
# Second mapping pass: run the torchvision preprocessing on the opened image.
image_mapping = {
myimg: imagepreproc
}

# Nominal number of batches, derived from the nominal DATASET_SIZE above.
num_batches = DATASET_SIZE // BATCH_SIZE

ds = (
wds.WebDataset(DATASET, length=num_batches)
# .shuffle(is_shuffle) # Commented out for WebDataset as the behaviour cannot be predicted yet
.map_dict(**image_text_mapping)
.map_dict(**image_mapping)
# Emit (caption, image) so samples unpack as `(text, images)` in the training loop.
.to_tuple(mycap, myimg)
.batched(BATCH_SIZE, partial=False) # It is good to avoid partial batches when using Distributed training
)
else:
# Folder mode: TextImageDataset receives the same sequence length, image
# size, crop ratio, truncation flag and tokenizer settings.
ds = TextImageDataset(
args.image_text_folder,
text_len=TEXT_SEQ_LEN,
image_size=IMAGE_SIZE,
resize_ratio=args.resize_ratio,
truncate_captions=args.truncate_captions,
tokenizer=tokenizer,
shuffle=is_shuffle,
)

assert len(ds) > 0, 'dataset is empty'
if distr_backend.is_root_worker():
print(f'{len(ds)} image-text pairs found for training')
if not ENABLE_WEBDATASET:
print(f'{len(ds)} image-text pairs found for training')

if not is_shuffle:
data_sampler = torch.utils.data.distributed.DistributedSampler(
Expand All @@ -306,10 +381,18 @@ def group_weight(model):
else:
data_sampler = None

dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=is_shuffle, drop_last=True, sampler=data_sampler)
# Create the loader `dl` that feeds the training loop.
if ENABLE_WEBDATASET:
# WebLoader for WebDataset and DeepSpeed compatibility
# batch_size=None because `ds` is already batched by its own .batched() stage.
dl = wds.WebLoader(ds, batch_size=None, shuffle=False) # optionally add num_workers=2 (n) argument
# Nominal number of batches per worker: the nominal dataset size split across
# the global batch (BATCH_SIZE times the number of distributed workers).
number_of_batches = DATASET_SIZE // (BATCH_SIZE * distr_backend.get_world_size())
# NOTE(review): presumably repeats the stream and then caps it at
# `number_of_batches` batches - confirm against the webdataset documentation.
dl = dl.repeat(2).slice(number_of_batches)
# Record the nominal length on the loader object.
dl.length = number_of_batches
else:
# Regular DataLoader for image-text-folder datasets
dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=is_shuffle, drop_last=True, sampler=data_sampler)

# initialize DALL-E

# initialize DALL-E

dalle = DALLE(vae=vae, **dalle_params)
if not using_deepspeed:
Expand Down Expand Up @@ -454,13 +537,13 @@ def save_model(path, epoch=0):

# training

# Saves a checkpoint before training begins to fail early when mis-configured.
# Saves a checkpoint before training begins to fail early when mis-configured.
# See https://github.com/lucidrains/DALLE-pytorch/wiki/DeepSpeed-Checkpoints
save_model(DALLE_OUTPUT_FILE_NAME, epoch=resume_epoch)
for epoch in range(resume_epoch, EPOCHS):
if data_sampler:
data_sampler.set_epoch(epoch)
for i, (text, images) in enumerate(distr_dl):
for i, (text, images) in enumerate((dl if ENABLE_WEBDATASET else distr_dl)):
if i % 10 == 0 and distr_backend.is_root_worker():
t = time.time()
if args.fp16:
Expand Down

0 comments on commit 2eceb84

Please sign in to comment.