Skip to content

Commit

Permalink
Merge pull request #49 from visionNoob/issue_45
Browse files Browse the repository at this point in the history
Add training for RetinaNet (#45, #47)
  • Loading branch information
MijeongJeon authored Oct 20, 2019
2 parents b8f1d5c + 70c8a0c commit dc6e38f
Show file tree
Hide file tree
Showing 7 changed files with 361 additions and 17 deletions.
26 changes: 17 additions & 9 deletions examples/retinanet_inference_example.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,32 @@
import sys, os
sys.path.insert(0, os.path.abspath('..'))
import yolk
from PIL import Image

# import miscellaneous modules
import matplotlib.pyplot as plt
import cv2
import numpy as np
import yolk
from yolk.parser import parse_args


def main(args=None):
    """ Run RetinaNet inference on the bundled sample image and print the raw model output.
    Args
        args : Command-line style argument list handed to parse_args.
               When None, sys.argv[1:] is used.
    """
    if args is None:
        args = sys.argv[1:]

    args = parse_args(args)

    image = Image.open('./000000008021.jpg')
    # The detector wrapper is declared as preprocessing_image(image, args),
    # so the parsed arguments are forwarded explicitly.
    image, scale = yolk.detector.preprocessing_image(image, args)

    # Load the model once, through the inference loader.
    model_path = os.path.join('..', 'resnet50_coco_best_v2.1.0.h5')
    model = yolk.detector.load_inference_model(model_path, args)

    # Single forward pass on a batch that contains only the sample image.
    model_output = model.predict_on_batch(np.expand_dims(image, axis=0))
    print(model_output)

    # TODO: decode model_output into boxes / scores / labels once a
    # post-processing helper is implemented in yolk.detector. `scale` is kept
    # for that step (presumably to map boxes back to the original image size).


if __name__ == '__main__':
    main()
40 changes: 40 additions & 0 deletions examples/retinanet_training_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import sys, os
sys.path.insert(0, os.path.abspath('..'))
from PIL import Image
import matplotlib.pyplot as plt
import cv2
import numpy as np
import yolk
from yolk.parser import parse_args
import keras

def main(args=None):
    """ Train a RetinaNet model with the data generator and settings selected by the arguments.
    Args
        args : Command-line style argument list handed to parse_args.
               When None, sys.argv[1:] is used.
    """
    if args is None:
        args = sys.argv[1:]

    args = parse_args(args)

    # Build the generator first so the model can be sized from the dataset's
    # own class count instead of a hard-coded 20.
    train_generator = yolk.detector.get_data_generator(args)
    model = yolk.detector.load_training_model(train_generator.num_classes(), args)

    model.compile(
        loss=yolk.detector.get_losses(args),
        # `Adam` is the canonical class name; the lowercase `adam` alias is
        # not available in every Keras release.
        optimizer=keras.optimizers.Adam(lr=args.lr, clipnorm=0.001),
    )

    model.fit_generator(
        generator=train_generator,
        steps_per_epoch=10000,
        epochs=50,
        verbose=1,
        callbacks=None,
        workers=1,
        use_multiprocessing=True,
        max_queue_size=10,
        validation_data=None,
    )


if __name__ == '__main__':
    main()
3 changes: 2 additions & 1 deletion yolk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import absolute_import
from . import backend
from .detector import *
from .detector import *
from .parser import *
5 changes: 4 additions & 1 deletion yolk/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from .load_backend import backend
from .load_backend import load_model
from .load_backend import load_inference_model
from .load_backend import load_training_model
from .load_backend import preprocess_image
from .load_backend import get_losses
from .load_backend import create_generators
218 changes: 215 additions & 3 deletions yolk/backend/retinanet_backend.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,228 @@
import sys, os
sys.path.insert(0, os.path.abspath('..'))

import numpy as np
import keras_retinanet
from keras_retinanet import models
from keras_retinanet.utils.image import read_image_bgr, preprocess_image, resize_image
from keras_retinanet.models.retinanet import retinanet_bbox
from keras_retinanet import losses
from keras_retinanet.utils.transform import random_transform_generator
from keras_retinanet.utils.image import random_visual_effect_generator
from keras_retinanet.preprocessing.pascal_voc import PascalVocGenerator

def load_model(model_path, backbone='resnet50'):
    """ Load a saved RetinaNet model from disk.
    Args
        model_path : Path of the model file to load.
        backbone   : Name of the backbone the model was built with.
    Returns
        The model returned by keras_retinanet's models.load_model.
    """
    model = models.load_model(model_path, backbone_name=backbone)
    print("type:", type(model))
    # Hand the model back: the detector wrapper forwards this return value.
    return model
def model_with_weights(model, weights, skip_mismatch):
    """ Optionally load weights into a model and return that same model.
    Args
        model         : The model to load weights for.
        weights       : The weights to load; None leaves the model untouched.
        skip_mismatch : If True, skips layers whose shape of weights doesn't match with the model.
    Returns
        The model object that was passed in.
    """
    # Nothing to load: return the model as-is.
    if weights is None:
        return model

    # Weights are matched to layers by name.
    model.load_weights(weights, by_name=True, skip_mismatch=skip_mismatch)
    return model

def create_models(backbone_retinanet, num_classes, weights, multi_gpu=0,
                  freeze_backbone=False, lr=1e-5, config=None):
    """ Creates three models (model, training_model, prediction_model).
    Args
        backbone_retinanet : A function to call to create a retinanet model with a given backbone.
        num_classes        : The number of classes to train.
        weights            : The weights to load into the model.
        multi_gpu          : The number of GPUs to use for training.
        freeze_backbone    : If True, disables learning for the backbone.
        lr                 : Learning rate. Kept for interface compatibility; this function
                             does not compile the model and does not read the value.
        config             : Config parameters, None indicates the default configuration.
    Returns
        model            : The base model. This is also the model that is saved in snapshots.
        training_model   : The training model. If multi_gpu is not greater than 1, this is identical to model.
        prediction_model : The model wrapped with utility functions to perform object detection (applies regression values and performs NMS).
    """
    # The helpers below are not part of this file's import block, so each is
    # imported in the branch that needs it.
    modifier = None
    if freeze_backbone:
        from keras_retinanet.utils.model import freeze as freeze_model
        modifier = freeze_model

    # load anchor parameters, or pass None (so that defaults will be used)
    anchor_params = None
    num_anchors = None
    if config and 'anchor_parameters' in config:
        from keras_retinanet.utils.config import parse_anchor_parameters
        anchor_params = parse_anchor_parameters(config)
        num_anchors = anchor_params.num_anchors()

    def _build():
        # Same construction for the single- and multi-GPU paths.
        return model_with_weights(
            backbone_retinanet(num_classes, num_anchors=num_anchors, modifier=modifier),
            weights=weights,
            skip_mismatch=True,
        )

    # Keras recommends initialising a multi-gpu model on the CPU to ease weight sharing, and to prevent OOM errors.
    # optionally wrap in a parallel model
    if multi_gpu > 1:
        import tensorflow as tf
        from keras.utils import multi_gpu_model
        with tf.device('/cpu:0'):
            model = _build()
        training_model = multi_gpu_model(model, gpus=multi_gpu)
    else:
        model = _build()
        training_model = model

    # make prediction model
    prediction_model = retinanet_bbox(model=model, anchor_params=anchor_params)

    return model, training_model, prediction_model

def load_inference_model(model_path, args):
    """ Load a saved model for inference.
    Args
        model_path : Path of the model file to load.
        args       : Parsed arguments; only args.backbone is read.
    Returns
        The model returned by keras_retinanet's models.load_model.
    """
    backbone_name = args.backbone
    inference_model = models.load_model(model_path, backbone_name=backbone_name)
    return inference_model

def load_training_model(num_classes, args):
    """ Build a RetinaNet training model initialised from the backbone's ImageNet weights.
    Args
        num_classes : The number of classes to train.
        args        : Parsed arguments; reads backbone, multi_gpu, freeze_backbone, lr and config.
    Returns
        The training model produced by create_models (the base and prediction models are discarded).
    """
    backbone = models.backbone(args.backbone)
    imagenet_weights = backbone.download_imagenet()

    # create_models returns (model, training_model, prediction_model);
    # only the middle entry is needed here.
    _, training_model, _ = create_models(
        backbone_retinanet=backbone.retinanet,
        num_classes=num_classes,
        weights=imagenet_weights,
        multi_gpu=args.multi_gpu,
        freeze_backbone=args.freeze_backbone,
        lr=args.lr,
        config=args.config,
    )
    return training_model

def preprocess_image(image):
    """ Prepare an image for the network.
    Args
        image : Array-like image with the channels on the third axis.
    Returns
        image : The preprocessed and resized image.
        scale : The scale value reported by resize_image.
    """
    # Reverse the channel axis (presumably RGB -> BGR; confirm against the caller).
    channel_flipped = np.asarray(image)[:, :, ::-1]
    # Called by its fully qualified name because this function shadows the
    # imported preprocess_image.
    normalized = keras_retinanet.utils.image.preprocess_image(channel_flipped)
    resized, scale = resize_image(normalized)
    return resized, scale

def get_losses(args):
    """ Build the loss functions for the two RetinaNet outputs.
    Args
        args : Parsed arguments (not read by this function).
    Returns
        A dict with a smooth L1 loss under 'regression' and a focal loss under 'classification'.
    """
    regression_loss = losses.smooth_l1()
    classification_loss = losses.focal()
    return {
        'regression': regression_loss,
        'classification': classification_loss,
    }

def create_generators(args):
    """ Create the training and validation data generators for the selected dataset.
    Args
        args : Parsed arguments. Reads backbone, batch_size, config, image_min_side,
               image_max_side, random_transform and dataset_type, plus the
               dataset-specific options of the chosen dataset type.
    Returns
        train_generator      : Generator over the training split (with augmentation).
        validation_generator : Generator over the validation split; None for a 'csv'
                               dataset without val_annotations.
    Raises
        ValueError : If args.dataset_type is not one of 'coco', 'pascal', 'csv', 'oid', 'kitti'.
    """
    backbone = models.backbone(args.backbone)

    common_args = {
        'batch_size'       : args.batch_size,
        'config'           : args.config,
        'image_min_side'   : args.image_min_side,
        'image_max_side'   : args.image_max_side,
        'preprocess_image' : backbone.preprocess_image,
    }

    # create random transform generator for augmenting training data
    if args.random_transform:
        transform_generator = random_transform_generator(
            min_rotation=-0.1,
            max_rotation=0.1,
            min_translation=(-0.1, -0.1),
            max_translation=(0.1, 0.1),
            min_shear=-0.1,
            max_shear=0.1,
            min_scaling=(0.9, 0.9),
            max_scaling=(1.1, 1.1),
            flip_x_chance=0.5,
            flip_y_chance=0.5,
        )
        visual_effect_generator = random_visual_effect_generator(
            contrast_range=(0.9, 1.1),
            brightness_range=(-.1, .1),
            hue_range=(-0.05, 0.05),
            saturation_range=(0.95, 1.05)
        )
    else:
        transform_generator = random_transform_generator(flip_x_chance=0.5)
        visual_effect_generator = None

    # Only PascalVocGenerator is imported at file level; the other generator
    # classes are imported from keras_retinanet in the branch that uses them.
    if args.dataset_type == 'coco':
        # import here to prevent unnecessary dependency on cocoapi
        from keras_retinanet.preprocessing.coco import CocoGenerator

        train_generator = CocoGenerator(
            args.coco_path,
            'train2017',
            transform_generator=transform_generator,
            visual_effect_generator=visual_effect_generator,
            **common_args
        )

        validation_generator = CocoGenerator(
            args.coco_path,
            'val2017',
            shuffle_groups=False,
            **common_args
        )
    elif args.dataset_type == 'pascal':
        train_generator = PascalVocGenerator(
            args.pascal_path,
            'trainval',
            transform_generator=transform_generator,
            visual_effect_generator=visual_effect_generator,
            **common_args
        )

        validation_generator = PascalVocGenerator(
            args.pascal_path,
            'test',
            shuffle_groups=False,
            **common_args
        )
    elif args.dataset_type == 'csv':
        from keras_retinanet.preprocessing.csv_generator import CSVGenerator

        train_generator = CSVGenerator(
            args.annotations,
            args.classes,
            transform_generator=transform_generator,
            visual_effect_generator=visual_effect_generator,
            **common_args
        )

        if args.val_annotations:
            validation_generator = CSVGenerator(
                args.val_annotations,
                args.classes,
                shuffle_groups=False,
                **common_args
            )
        else:
            validation_generator = None
    elif args.dataset_type == 'oid':
        from keras_retinanet.preprocessing.open_images import OpenImagesGenerator

        train_generator = OpenImagesGenerator(
            args.main_dir,
            subset='train',
            version=args.version,
            labels_filter=args.labels_filter,
            annotation_cache_dir=args.annotation_cache_dir,
            parent_label=args.parent_label,
            transform_generator=transform_generator,
            visual_effect_generator=visual_effect_generator,
            **common_args
        )

        validation_generator = OpenImagesGenerator(
            args.main_dir,
            subset='validation',
            version=args.version,
            labels_filter=args.labels_filter,
            annotation_cache_dir=args.annotation_cache_dir,
            parent_label=args.parent_label,
            shuffle_groups=False,
            **common_args
        )
    elif args.dataset_type == 'kitti':
        from keras_retinanet.preprocessing.kitti import KittiGenerator

        train_generator = KittiGenerator(
            args.kitti_path,
            subset='train',
            transform_generator=transform_generator,
            visual_effect_generator=visual_effect_generator,
            **common_args
        )

        validation_generator = KittiGenerator(
            args.kitti_path,
            subset='val',
            shuffle_groups=False,
            **common_args
        )
    else:
        raise ValueError('Invalid data type received: {}'.format(args.dataset_type))

    return train_generator, validation_generator
21 changes: 18 additions & 3 deletions yolk/detector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,22 @@
from . import backend as M

def load_model(path, backbone='resnet50'):
    """ Load a detection model through the backend.
    Args
        path     : Path of the model file to load.
        backbone : Name of the backbone, forwarded to the backend unchanged.
    Returns
        Whatever the backend's load_model returns.
    """
    loaded = M.load_model(path, backbone)
    return loaded
def load_inference_model(path, args):
    """ Load a model for inference through the backend.
    Args
        path : Path of the model file to load.
        args : Parsed arguments, forwarded to the backend unchanged.
    Returns
        Whatever the backend's load_inference_model returns.
    """
    inference_model = M.load_inference_model(path, args)
    return inference_model

def preprocessing_image(image):
def load_training_model(num_classes, args):
    """ Build a training model through the backend.
    Args
        num_classes : The number of classes to train.
        args        : Parsed arguments, forwarded to the backend unchanged.
    Returns
        Whatever the backend's load_training_model returns.
    """
    training_model = M.load_training_model(num_classes, args)
    return training_model

def preprocessing_image(image, args=None):
    """ Preprocess an image through the backend.
    Args
        image : The image to preprocess, forwarded to the backend unchanged.
        args  : Parsed arguments. Currently not used; optional so that callers
                that pass only the image keep working.
    Returns
        Whatever the backend's preprocess_image returns.
    """
    return M.preprocess_image(image)

def get_data_generator(args):
    """ Create the training data generator through the backend.
    Args
        args : Parsed arguments, forwarded to the backend unchanged.
    Returns
        The training generator; the validation generator that the backend
        also creates is discarded.
    """
    # The backend hands back a (train, validation) pair.
    train_generator, _ = M.create_generators(args)
    return train_generator

def prostprocessing_image(model_output, args):
    """ Placeholder for post-processing of raw model output.
    Args
        model_output : Raw model output (not read yet).
        args         : Parsed arguments (not read yet).
    Returns
        Always the constant 1 until real post-processing is implemented.
    """
    # TODO: call the backend's post-processing on model_output and return
    # its boxes / scores / labels instead of this constant.
    placeholder = 1
    return placeholder

def get_losses(args):
    """ Fetch the training losses from the backend.
    Args
        args : Parsed arguments, forwarded to the backend unchanged.
    Returns
        Whatever the backend's get_losses returns.
    """
    backend_losses = M.get_losses(args)
    return backend_losses

Loading

0 comments on commit dc6e38f

Please sign in to comment.