# Review notes (text before the first "diff --git" header is ignored by git apply):
# - Line breaks and indentation were restored; hunk counts are updated for the edited files.
# - examples/retinanet_training_example.py: the class count comes from the generator and
#   fit_generator() uses the parsed --steps/--epochs/--workers/--multiprocessing/--max-queue-size.
# - yolk/backend/retinanet_backend.py: imports the names that were used but never defined
#   (freeze_model, parse_anchor_parameters, tf, CSVGenerator, KittiGenerator, OpenImagesGenerator),
#   imports CocoGenerator from keras_retinanet instead of the relative "..preprocessing" path
#   (which resolves to yolk.preprocessing, a package this patch does not add), and parses the
#   --config path before it is used as a configuration object.
# - yolk/detector.py: preprocessing_image() takes an optional args so the one-argument call in the
#   inference example keeps working; postprocessing_image is spelled correctly (old name aliased).
# - NOTE(review): yolk/backend/load_backend.py is not part of this patch; confirm it exports the
#   names imported by yolk/backend/__init__.py.
diff --git a/examples/retinanet_inference_example.py b/examples/retinanet_inference_example.py
index 17480b0..4d63c9f 100644
--- a/examples/retinanet_inference_example.py
+++ b/examples/retinanet_inference_example.py
@@ -1,24 +1,32 @@
 import sys, os
 sys.path.insert(0, os.path.abspath('..'))
-import yolk
 from PIL import Image
-
-# import miscellaneous modules
 import matplotlib.pyplot as plt
 import cv2
 import numpy as np
+import yolk
+from yolk.parser import parse_args
+
 
-def main():
-    image = np.asarray(Image.open('000000008021.jpg').convert('RGB'))
-    image = image[:, :, ::-1].copy()
+def main(args=None):
+    if args is None:
+        args = sys.argv[1:]
+
+    args = parse_args(args)
+
+    image = Image.open('./000000008021.jpg')
     image, scale = yolk.detector.preprocessing_image(image)
 
     model_path = os.path.join('..', 'resnet50_coco_best_v2.1.0.h5')
-    model = yolk.detector.load_model(model_path)
+    model = yolk.detector.load_inference_model(model_path, args)
 
-    boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0))
+    model_output = model.predict_on_batch(np.expand_dims(image, axis=0))
 
-    print(boxes)
+    print(model_output)
+
+    #shape = image.shape
+    #boxes, scores, labels = yolk.detector.postprocessing(model_output, original_shape=image.shape, args)
+    #print(boxes)
 
 if __name__ == '__main__':
     main()
diff --git a/examples/retinanet_training_example.py b/examples/retinanet_training_example.py
new file mode 100644
--- /dev/null
+++ b/examples/retinanet_training_example.py
@@ -0,0 +1,42 @@
+import sys, os
+sys.path.insert(0, os.path.abspath('..'))
+from PIL import Image
+import matplotlib.pyplot as plt
+import cv2
+import numpy as np
+import yolk
+from yolk.parser import parse_args
+import keras
+
+def main(args=None):
+
+    if args is None:
+        args = sys.argv[1:]
+
+    args = parse_args(args)
+
+    train_generator = yolk.detector.get_data_generator(args)
+
+    # size the classification head from the dataset instead of hard-coding 20 classes
+    model = yolk.detector.load_training_model(train_generator.num_classes(), args)
+
+    model.compile(
+        loss=yolk.detector.get_losses(args),
+        optimizer=keras.optimizers.adam(lr=args.lr, clipnorm=0.001)
+    )
+
+    # honour the training options declared in yolk/parser.py
+    model.fit_generator(
+        generator=train_generator,
+        steps_per_epoch=args.steps,
+        epochs=args.epochs,
+        verbose=1,
+        callbacks=None,
+        workers=args.workers,
+        use_multiprocessing=args.multiprocessing,
+        max_queue_size=args.max_queue_size,
+        validation_data=None
+    )
+
+if __name__ == '__main__':
+    main()
diff --git a/yolk/__init__.py b/yolk/__init__.py
index e73865e..186efdf 100644
--- a/yolk/__init__.py
+++ b/yolk/__init__.py
@@ -1,3 +1,4 @@
 from __future__ import absolute_import
 from . import backend
-from .detector import *
\ No newline at end of file
+from .detector import *
+from .parser import *
diff --git a/yolk/backend/__init__.py b/yolk/backend/__init__.py
index fe0e3f9..2502790 100644
--- a/yolk/backend/__init__.py
+++ b/yolk/backend/__init__.py
@@ -1,3 +1,6 @@
 from .load_backend import backend
-from .load_backend import load_model
+from .load_backend import load_inference_model
+from .load_backend import load_training_model
 from .load_backend import preprocess_image
+from .load_backend import get_losses
+from .load_backend import create_generators
diff --git a/yolk/backend/retinanet_backend.py b/yolk/backend/retinanet_backend.py
--- a/yolk/backend/retinanet_backend.py
+++ b/yolk/backend/retinanet_backend.py
@@ -1,16 +1,241 @@
 import sys, os
 sys.path.insert(0, os.path.abspath('..'))
 
+import numpy as np
 import keras_retinanet
 from keras_retinanet import models
 from keras_retinanet.utils.image import read_image_bgr, preprocess_image, resize_image
+from keras_retinanet.models.retinanet import retinanet_bbox
+from keras_retinanet import losses
+from keras_retinanet.utils.config import read_config_file, parse_anchor_parameters
+from keras_retinanet.utils.model import freeze as freeze_model
+from keras_retinanet.utils.transform import random_transform_generator
+from keras_retinanet.utils.image import random_visual_effect_generator
+from keras_retinanet.preprocessing.csv_generator import CSVGenerator
+from keras_retinanet.preprocessing.kitti import KittiGenerator
+from keras_retinanet.preprocessing.open_images import OpenImagesGenerator
+from keras_retinanet.preprocessing.pascal_voc import PascalVocGenerator
 
-def load_model(model_path, backbone='resnet50'):
-    model = models.load_model(model_path, backbone_name=backbone)
-    print("type:", type(model))
+def model_with_weights(model, weights, skip_mismatch):
+    """ Load weights for model.
+
+    Args
+        model : The model to load weights for.
+        weights : The weights to load.
+        skip_mismatch : If True, skips layers whose shape of weights doesn't match with the model.
+    """
+    if weights is not None:
+        model.load_weights(weights, by_name=True, skip_mismatch=skip_mismatch)
     return model
 
+def _read_config(config):
+    """ Parse a configuration .ini path; None and already-parsed configs pass through unchanged. """
+    if isinstance(config, str):
+        return read_config_file(config)
+    return config
+
+def create_models(backbone_retinanet, num_classes, weights, multi_gpu=0,
+                  freeze_backbone=False, lr=1e-5, config=None):
+    """ Creates three models (model, training_model, prediction_model).
+
+    Args
+        backbone_retinanet : A function to call to create a retinanet model with a given backbone.
+        num_classes : The number of classes to train.
+        weights : The weights to load into the model.
+        multi_gpu : The number of GPUs to use for training.
+        freeze_backbone : If True, disables learning for the backbone.
+        lr : Learning rate (accepted for compatibility; not used by this function).
+        config : Config parameters, None indicates the default configuration.
+
+    Returns
+        model : The base model. This is also the model that is saved in snapshots.
+        training_model : The training model. If multi_gpu=0, this is identical to model.
+        prediction_model : The model wrapped with utility functions to perform object detection (applies regression values and performs NMS).
+    """
+
+    modifier = freeze_model if freeze_backbone else None
+
+    # load anchor parameters, or pass None (so that defaults will be used)
+    anchor_params = None
+    num_anchors = None
+    if config and 'anchor_parameters' in config:
+        anchor_params = parse_anchor_parameters(config)
+        num_anchors = anchor_params.num_anchors()
+
+    # Keras recommends initialising a multi-gpu model on the CPU to ease weight sharing, and to prevent OOM errors.
+    # optionally wrap in a parallel model
+    if multi_gpu > 1:
+        import tensorflow as tf
+        from keras.utils import multi_gpu_model
+        with tf.device('/cpu:0'):
+            model = model_with_weights(backbone_retinanet(num_classes, num_anchors=num_anchors, modifier=modifier), weights=weights, skip_mismatch=True)
+        training_model = multi_gpu_model(model, gpus=multi_gpu)
+    else:
+        model = model_with_weights(backbone_retinanet(num_classes, num_anchors=num_anchors, modifier=modifier), weights=weights, skip_mismatch=True)
+        training_model = model
+
+    # make prediction model
+    prediction_model = retinanet_bbox(model=model, anchor_params=anchor_params)
+
+    return model, training_model, prediction_model
+
+def load_inference_model(model_path, args):
+    return models.load_model(model_path, backbone_name=args.backbone)
+
+def load_training_model(num_classes, args):
+    backbone = models.backbone(args.backbone)
+    weights = backbone.download_imagenet()
+
+    model, training_model, prediction_model = create_models(
+        backbone_retinanet=backbone.retinanet,
+        num_classes=num_classes,
+        weights=weights,
+        multi_gpu=args.multi_gpu,
+        freeze_backbone=args.freeze_backbone,
+        lr=args.lr,
+        config=_read_config(args.config)
+    )
+    return training_model
+
 def preprocess_image(image):
+    image = np.asarray(image)[:, :, ::-1]
     image = keras_retinanet.utils.image.preprocess_image(image)
     image, scale = resize_image(image)
     return image, scale
+
+def get_losses(args):
+    return {
+        'regression': losses.smooth_l1(),
+        'classification': losses.focal()
+    }
+
+def create_generators(args):
+
+    backbone = models.backbone(args.backbone)
+    preprocess_image = backbone.preprocess_image
+
+    common_args = {
+        'batch_size': args.batch_size,
+        'config': _read_config(args.config),
+        'image_min_side': args.image_min_side,
+        'image_max_side': args.image_max_side,
+        'preprocess_image': preprocess_image,
+    }
+
+    # create random transform generator for augmenting training data
+    if args.random_transform:
+        transform_generator = random_transform_generator(
+            min_rotation=-0.1,
+            max_rotation=0.1,
+            min_translation=(-0.1, -0.1),
+            max_translation=(0.1, 0.1),
+            min_shear=-0.1,
+            max_shear=0.1,
+            min_scaling=(0.9, 0.9),
+            max_scaling=(1.1, 1.1),
+            flip_x_chance=0.5,
+            flip_y_chance=0.5,
+        )
+        visual_effect_generator = random_visual_effect_generator(
+            contrast_range=(0.9, 1.1),
+            brightness_range=(-.1, .1),
+            hue_range=(-0.05, 0.05),
+            saturation_range=(0.95, 1.05)
+        )
+    else:
+        transform_generator = random_transform_generator(flip_x_chance=0.5)
+        visual_effect_generator = None
+
+    if args.dataset_type == 'coco':
+        # import here to prevent unnecessary dependency on cocoapi
+        from keras_retinanet.preprocessing.coco import CocoGenerator
+
+        train_generator = CocoGenerator(
+            args.coco_path,
+            'train2017',
+            transform_generator=transform_generator,
+            visual_effect_generator=visual_effect_generator,
+            **common_args
+        )
+
+        validation_generator = CocoGenerator(
+            args.coco_path,
+            'val2017',
+            shuffle_groups=False,
+            **common_args
+        )
+    elif args.dataset_type == 'pascal':
+        train_generator = PascalVocGenerator(
+            args.pascal_path,
+            'trainval',
+            transform_generator=transform_generator,
+            visual_effect_generator=visual_effect_generator,
+            **common_args
+        )
+
+        validation_generator = PascalVocGenerator(
+            args.pascal_path,
+            'test',
+            shuffle_groups=False,
+            **common_args
+        )
+    elif args.dataset_type == 'csv':
+        train_generator = CSVGenerator(
+            args.annotations,
+            args.classes,
+            transform_generator=transform_generator,
+            visual_effect_generator=visual_effect_generator,
+            **common_args
+        )
+
+        if args.val_annotations:
+            validation_generator = CSVGenerator(
+                args.val_annotations,
+                args.classes,
+                shuffle_groups=False,
+                **common_args
+            )
+        else:
+            validation_generator = None
+    elif args.dataset_type == 'oid':
+        train_generator = OpenImagesGenerator(
+            args.main_dir,
+            subset='train',
+            version=args.version,
+            labels_filter=args.labels_filter,
+            annotation_cache_dir=args.annotation_cache_dir,
+            parent_label=args.parent_label,
+            transform_generator=transform_generator,
+            visual_effect_generator=visual_effect_generator,
+            **common_args
+        )
+
+        validation_generator = OpenImagesGenerator(
+            args.main_dir,
+            subset='validation',
+            version=args.version,
+            labels_filter=args.labels_filter,
+            annotation_cache_dir=args.annotation_cache_dir,
+            parent_label=args.parent_label,
+            shuffle_groups=False,
+            **common_args
+        )
+    elif args.dataset_type == 'kitti':
+        train_generator = KittiGenerator(
+            args.kitti_path,
+            subset='train',
+            transform_generator=transform_generator,
+            visual_effect_generator=visual_effect_generator,
+            **common_args
+        )
+
+        validation_generator = KittiGenerator(
+            args.kitti_path,
+            subset='val',
+            shuffle_groups=False,
+            **common_args
+        )
+    else:
+        raise ValueError('Invalid data type received: {}'.format(args.dataset_type))
+
+    return train_generator, validation_generator
diff --git a/yolk/detector.py b/yolk/detector.py
--- a/yolk/detector.py
+++ b/yolk/detector.py
@@ -1,7 +1,25 @@
 from . import backend as M
 
-def load_model(path, backbone='resnet50'):
-    return M.load_model(path, backbone)
+def load_inference_model(path, args):
+    return M.load_inference_model(path, args)
 
-def preprocessing_image(image):
+def load_training_model(num_classes, args):
+    return M.load_training_model(num_classes, args)
+
+def preprocessing_image(image, args=None):
     return M.preprocess_image(image)
+
+def get_data_generator(args):
+    train_generator, validation_generator = M.create_generators(args)
+    return train_generator
+
+def postprocessing_image(model_output, args):
+    # TODO: post-processing is not implemented yet; placeholder return value
+    return 1
+
+# keep the original (misspelled) name available
+prostprocessing_image = postprocessing_image
+
+def get_losses(args):
+    return M.get_losses(args)
+
diff --git a/yolk/parser.py b/yolk/parser.py
new file mode 100644
index 0000000..250d71d
--- /dev/null
+++ b/yolk/parser.py
@@ -0,0 +1,65 @@
+import argparse
+
+def parse_args(args):
+    """ Parse the arguments.
+    """
+    parser = argparse.ArgumentParser(description='Simple training script for training a RetinaNet network.')
+    subparsers = parser.add_subparsers(help='Arguments for specific dataset types.', dest='dataset_type')
+    subparsers.required = True
+
+    coco_parser = subparsers.add_parser('coco')
+    coco_parser.add_argument('coco_path', help='Path to dataset directory (ie. /tmp/COCO).')
+
+    pascal_parser = subparsers.add_parser('pascal')
+    pascal_parser.add_argument('pascal_path', help='Path to dataset directory (ie. /tmp/VOCdevkit).')
+
+    kitti_parser = subparsers.add_parser('kitti')
+    kitti_parser.add_argument('kitti_path', help='Path to dataset directory (ie. /tmp/kitti).')
+
+    def csv_list(string):
+        return string.split(',')
+
+    oid_parser = subparsers.add_parser('oid')
+    oid_parser.add_argument('main_dir', help='Path to dataset directory.')
+    oid_parser.add_argument('--version', help='The current dataset version is v4.', default='v4')
+    oid_parser.add_argument('--labels-filter', help='A list of labels to filter.', type=csv_list, default=None)
+    oid_parser.add_argument('--annotation-cache-dir', help='Path to store annotation cache.', default='.')
+    oid_parser.add_argument('--parent-label', help='Use the hierarchy children of this label.', default=None)
+
+    csv_parser = subparsers.add_parser('csv')
+    csv_parser.add_argument('annotations', help='Path to CSV file containing annotations for training.')
+    csv_parser.add_argument('classes', help='Path to a CSV file containing class label mapping.')
+    csv_parser.add_argument('--val-annotations', help='Path to CSV file containing annotations for validation (optional).')
+
+    group = parser.add_mutually_exclusive_group()
+    group.add_argument('--snapshot', help='Resume training from a snapshot.')
+    group.add_argument('--imagenet-weights', help='Initialize the model with pretrained imagenet weights. This is the default behaviour.', action='store_const', const=True, default=True)
+    group.add_argument('--weights', help='Initialize the model with weights from a file.')
+    group.add_argument('--no-weights', help='Don\'t initialize the model with any weights.', dest='imagenet_weights', action='store_const', const=False)
+
+    parser.add_argument('--backbone', help='Backbone model used by retinanet.', default='resnet50', type=str)
+    parser.add_argument('--batch-size', help='Size of the batches.', default=1, type=int)
+    parser.add_argument('--gpu', help='Id of the GPU to use (as reported by nvidia-smi).')
+    parser.add_argument('--multi-gpu', help='Number of GPUs to use for parallel processing.', type=int, default=0)
+    parser.add_argument('--multi-gpu-force', help='Extra flag needed to enable (experimental) multi-gpu support.', action='store_true')
+    parser.add_argument('--epochs', help='Number of epochs to train.', type=int, default=50)
+    parser.add_argument('--steps', help='Number of steps per epoch.', type=int, default=10000)
+    parser.add_argument('--lr', help='Learning rate.', type=float, default=1e-5)
+    parser.add_argument('--snapshot-path', help='Path to store snapshots of models during training (defaults to \'./snapshots\')', default='./snapshots')
+    parser.add_argument('--tensorboard-dir', help='Log directory for Tensorboard output', default='./logs')
+    parser.add_argument('--no-snapshots', help='Disable saving snapshots.', dest='snapshots', action='store_false')
+    parser.add_argument('--no-evaluation', help='Disable per epoch evaluation.', dest='evaluation', action='store_false')
+    parser.add_argument('--freeze-backbone', help='Freeze training of backbone layers.', action='store_true')
+    parser.add_argument('--random-transform', help='Randomly transform image and annotations.', action='store_true')
+    parser.add_argument('--image-min-side', help='Rescale the image so the smallest side is min_side.', type=int, default=800)
+    parser.add_argument('--image-max-side', help='Rescale the image if the largest side is larger than max_side.', type=int, default=1333)
+    parser.add_argument('--config', help='Path to a configuration parameters .ini file.')
+    parser.add_argument('--weighted-average', help='Compute the mAP using the weighted average of precisions among classes.', action='store_true')
+    parser.add_argument('--compute-val-loss', help='Compute validation loss during training', dest='compute_val_loss', action='store_true')
+
+    # Fit generator arguments
+    parser.add_argument('--multiprocessing', help='Use multiprocessing in fit_generator.', action='store_true')
+    parser.add_argument('--workers', help='Number of generator workers.', type=int, default=1)
+    parser.add_argument('--max-queue-size', help='Queue length for multiprocessing workers in fit_generator.', type=int, default=10)
+
+    return parser.parse_args(args)
\ No newline at end of file