-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #49 from visionNoob/issue_45
- Loading branch information
Showing
7 changed files
with
361 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,32 @@ | ||
import sys, os | ||
sys.path.insert(0, os.path.abspath('..')) | ||
import yolk | ||
from PIL import Image | ||
|
||
# import miscellaneous modules | ||
import matplotlib.pyplot as plt | ||
import cv2 | ||
import numpy as np | ||
import yolk | ||
from yolk.parser import parse_args | ||
|
||
|
||
def main(): | ||
image = np.asarray(Image.open('000000008021.jpg').convert('RGB')) | ||
image = image[:, :, ::-1].copy() | ||
def main(args=None): | ||
if args is None: | ||
args = sys.argv[1:] | ||
|
||
args = parse_args(args) | ||
|
||
image = Image.open('./000000008021.jpg') | ||
image, scale = yolk.detector.preprocessing_image(image) | ||
|
||
model_path = os.path.join('..', 'resnet50_coco_best_v2.1.0.h5') | ||
model = yolk.detector.load_model(model_path) | ||
model = yolk.detector.load_inference_model(model_path, args) | ||
|
||
boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0)) | ||
model_output = model.predict_on_batch(np.expand_dims(image, axis=0)) | ||
|
||
print(boxes) | ||
print(model_output) | ||
|
||
#shape = image.shape | ||
#boxes, scores, labels = yolk.detector.postprocessing(model_output, original_shape=image.shape, args) | ||
#print(boxes) | ||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import sys, os | ||
sys.path.insert(0, os.path.abspath('..')) | ||
from PIL import Image | ||
import matplotlib.pyplot as plt | ||
import cv2 | ||
import numpy as np | ||
import yolk | ||
from yolk.parser import parse_args | ||
import keras | ||
|
||
def main(args=None): | ||
|
||
if args is None: | ||
args = sys.argv[1:] | ||
|
||
args = parse_args(args) | ||
|
||
model = yolk.detector.load_training_model(20, args) | ||
|
||
train_generator = yolk.detector.get_data_generator(args) | ||
|
||
model.compile( | ||
loss=yolk.detector.get_losses(args), | ||
optimizer=keras.optimizers.adam(lr=args.lr, clipnorm=0.001) | ||
) | ||
|
||
model.fit_generator( | ||
generator=train_generator, | ||
steps_per_epoch=10000, | ||
epochs=50, | ||
verbose=1, | ||
callbacks=None, | ||
workers=1, | ||
use_multiprocessing=True, | ||
max_queue_size=10, | ||
validation_data=None | ||
) | ||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from __future__ import absolute_import | ||
from . import backend | ||
from .detector import * | ||
from .detector import * | ||
from .parser import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,6 @@ | ||
from .load_backend import backend | ||
from .load_backend import load_model | ||
from .load_backend import load_inference_model | ||
from .load_backend import load_training_model | ||
from .load_backend import preprocess_image | ||
from .load_backend import get_losses | ||
from .load_backend import create_generators |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,228 @@ | ||
import sys, os | ||
sys.path.insert(0, os.path.abspath('..')) | ||
|
||
import numpy as np | ||
import keras_retinanet | ||
from keras_retinanet import models | ||
from keras_retinanet.utils.image import read_image_bgr, preprocess_image, resize_image | ||
from keras_retinanet.models.retinanet import retinanet_bbox | ||
from keras_retinanet import losses | ||
from keras_retinanet.utils.transform import random_transform_generator | ||
from keras_retinanet.utils.image import random_visual_effect_generator | ||
from keras_retinanet.preprocessing.pascal_voc import PascalVocGenerator | ||
|
||
def load_model(model_path, backbone='resnet50'): | ||
model = models.load_model(model_path, backbone_name=backbone) | ||
print("type:", type(model)) | ||
def model_with_weights(model, weights, skip_mismatch): | ||
""" Load weights for model. | ||
Args | ||
model : The model to load weights for. | ||
weights : The weights to load. | ||
skip_mismatch : If True, skips layers whose shape of weights doesn't match with the model. | ||
""" | ||
if weights is not None: | ||
model.load_weights(weights, by_name=True, skip_mismatch=skip_mismatch) | ||
return model | ||
|
||
def create_models(backbone_retinanet, num_classes, weights, multi_gpu=0, | ||
freeze_backbone=False, lr=1e-5, config=None): | ||
""" Creates three models (model, training_model, prediction_model). | ||
Args | ||
backbone_retinanet : A function to call to create a retinanet model with a given backbone. | ||
num_classes : The number of classes to train. | ||
weights : The weights to load into the model. | ||
multi_gpu : The number of GPUs to use for training. | ||
freeze_backbone : If True, disables learning for the backbone. | ||
config : Config parameters, None indicates the default configuration. | ||
Returns | ||
model : The base model. This is also the model that is saved in snapshots. | ||
training_model : The training model. If multi_gpu=0, this is identical to model. | ||
prediction_model : The model wrapped with utility functions to perform object detection (applies regression values and performs NMS). | ||
""" | ||
|
||
modifier = freeze_model if freeze_backbone else None | ||
|
||
# load anchor parameters, or pass None (so that defaults will be used) | ||
anchor_params = None | ||
num_anchors = None | ||
if config and 'anchor_parameters' in config: | ||
anchor_params = parse_anchor_parameters(config) | ||
num_anchors = anchor_params.num_anchors() | ||
|
||
# Keras recommends initialising a multi-gpu model on the CPU to ease weight sharing, and to prevent OOM errors. | ||
# optionally wrap in a parallel model | ||
if multi_gpu > 1: | ||
from keras.utils import multi_gpu_model | ||
with tf.device('/cpu:0'): | ||
model = model_with_weights(backbone_retinanet(num_classes, num_anchors=num_anchors, modifier=modifier), weights=weights, skip_mismatch=True) | ||
training_model = multi_gpu_model(model, gpus=multi_gpu) | ||
else: | ||
model = model_with_weights(backbone_retinanet(num_classes, num_anchors=num_anchors, modifier=modifier), weights=weights, skip_mismatch=True) | ||
training_model = model | ||
|
||
# make prediction model | ||
prediction_model = retinanet_bbox(model=model, anchor_params=anchor_params) | ||
|
||
return model, training_model, prediction_model | ||
|
||
def load_inference_model(model_path, args): | ||
return models.load_model(model_path, backbone_name=args.backbone) | ||
|
||
def load_training_model(num_classes, args): | ||
backbone = models.backbone(args.backbone) | ||
weights = backbone.download_imagenet() | ||
|
||
model, training_model, prediction_model = create_models( | ||
backbone_retinanet=backbone.retinanet, | ||
num_classes=num_classes, | ||
weights=weights, | ||
multi_gpu=args.multi_gpu, | ||
freeze_backbone=args.freeze_backbone, | ||
lr=args.lr, | ||
config=args.config | ||
) | ||
return training_model | ||
|
||
def preprocess_image(image): | ||
image = np.asarray(image)[:, :, ::-1] | ||
image = keras_retinanet.utils.image.preprocess_image(image) | ||
image, scale = resize_image(image) | ||
return image, scale | ||
|
||
def get_losses(args): | ||
return { | ||
'regression' : losses.smooth_l1(), | ||
'classification': losses.focal() | ||
} | ||
|
||
def create_generators(args): | ||
|
||
backbone = models.backbone(args.backbone) | ||
preprocess_image = backbone.preprocess_image | ||
|
||
common_args = { | ||
'batch_size' : args.batch_size, | ||
'config' : args.config, | ||
'image_min_side' : args.image_min_side, | ||
'image_max_side' : args.image_max_side, | ||
'preprocess_image' : preprocess_image, | ||
} | ||
|
||
# create random transform generator for augmenting training data | ||
if args.random_transform: | ||
transform_generator = random_transform_generator( | ||
min_rotation=-0.1, | ||
max_rotation=0.1, | ||
min_translation=(-0.1, -0.1), | ||
max_translation=(0.1, 0.1), | ||
min_shear=-0.1, | ||
max_shear=0.1, | ||
min_scaling=(0.9, 0.9), | ||
max_scaling=(1.1, 1.1), | ||
flip_x_chance=0.5, | ||
flip_y_chance=0.5, | ||
) | ||
visual_effect_generator = random_visual_effect_generator( | ||
contrast_range=(0.9, 1.1), | ||
brightness_range=(-.1, .1), | ||
hue_range=(-0.05, 0.05), | ||
saturation_range=(0.95, 1.05) | ||
) | ||
else: | ||
transform_generator = random_transform_generator(flip_x_chance=0.5) | ||
visual_effect_generator = None | ||
|
||
if args.dataset_type == 'coco': | ||
# import here to prevent unnecessary dependency on cocoapi | ||
from ..preprocessing.coco import CocoGenerator | ||
|
||
train_generator = CocoGenerator( | ||
args.coco_path, | ||
'train2017', | ||
transform_generator=transform_generator, | ||
visual_effect_generator=visual_effect_generator, | ||
**common_args | ||
) | ||
|
||
validation_generator = CocoGenerator( | ||
args.coco_path, | ||
'val2017', | ||
shuffle_groups=False, | ||
**common_args | ||
) | ||
elif args.dataset_type == 'pascal': | ||
train_generator = PascalVocGenerator( | ||
args.pascal_path, | ||
'trainval', | ||
transform_generator=transform_generator, | ||
visual_effect_generator=visual_effect_generator, | ||
**common_args | ||
) | ||
|
||
validation_generator = PascalVocGenerator( | ||
args.pascal_path, | ||
'test', | ||
shuffle_groups=False, | ||
**common_args | ||
) | ||
elif args.dataset_type == 'csv': | ||
train_generator = CSVGenerator( | ||
args.annotations, | ||
args.classes, | ||
transform_generator=transform_generator, | ||
visual_effect_generator=visual_effect_generator, | ||
**common_args | ||
) | ||
|
||
if args.val_annotations: | ||
validation_generator = CSVGenerator( | ||
args.val_annotations, | ||
args.classes, | ||
shuffle_groups=False, | ||
**common_args | ||
) | ||
else: | ||
validation_generator = None | ||
elif args.dataset_type == 'oid': | ||
train_generator = OpenImagesGenerator( | ||
args.main_dir, | ||
subset='train', | ||
version=args.version, | ||
labels_filter=args.labels_filter, | ||
annotation_cache_dir=args.annotation_cache_dir, | ||
parent_label=args.parent_label, | ||
transform_generator=transform_generator, | ||
visual_effect_generator=visual_effect_generator, | ||
**common_args | ||
) | ||
|
||
validation_generator = OpenImagesGenerator( | ||
args.main_dir, | ||
subset='validation', | ||
version=args.version, | ||
labels_filter=args.labels_filter, | ||
annotation_cache_dir=args.annotation_cache_dir, | ||
parent_label=args.parent_label, | ||
shuffle_groups=False, | ||
**common_args | ||
) | ||
elif args.dataset_type == 'kitti': | ||
train_generator = KittiGenerator( | ||
args.kitti_path, | ||
subset='train', | ||
transform_generator=transform_generator, | ||
visual_effect_generator=visual_effect_generator, | ||
**common_args | ||
) | ||
|
||
validation_generator = KittiGenerator( | ||
args.kitti_path, | ||
subset='val', | ||
shuffle_groups=False, | ||
**common_args | ||
) | ||
else: | ||
raise ValueError('Invalid data type received: {}'.format(args.dataset_type)) | ||
|
||
return train_generator, validation_generator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,22 @@ | ||
from . import backend as M | ||
|
||
def load_model(path, backbone='resnet50'): | ||
return M.load_model(path, backbone) | ||
def load_inference_model(path, args): | ||
return M.load_inference_model(path, args) | ||
|
||
def preprocessing_image(image): | ||
def load_training_model(num_classes, args): | ||
return M.load_training_model(num_classes, args) | ||
|
||
def preprocessing_image(image, args): | ||
return M.preprocess_image(image) | ||
|
||
def get_data_generator(args): | ||
train_generator, validation_generator = M.create_generators(args) | ||
return train_generator | ||
|
||
def prostprocessing_image(model_output, args): | ||
#bbox ,score ,~ = .prostprocessing_image(model_output) | ||
return 1#bbox ,score ,~ | ||
|
||
def get_losses(args): | ||
return M.get_losses(args) | ||
|
Oops, something went wrong.