From 5547563bca6177a5f6109032faf3d7acc88c8cee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=98=E9=BB=98?= Date: Mon, 2 Dec 2019 12:48:18 +0800 Subject: [PATCH] fix bug after train 2019-12-02 11:07:34,096 maskrcnn_benchmark.inference INFO: Start evaluation on 233 images 0it [00:01, ?it/s] Traceback (most recent call last): File "tools/train_net.py", line 174, in main() File "tools/train_net.py", line 170, in main test(cfg, model, args.distributed) File "tools/train_net.py", line 106, in test output_folder=output_folder, File "~/MaskTextSpotter/maskrcnn_benchmark/engine/inference.py", line 371, in inference predictions = compute_on_dataset(model, data_loader, device) File "~/MaskTextSpotter/maskrcnn_benchmark/engine/inference.py", line 32, in compute_on_dataset output = [o.to(cpu_device) for o in output] File "~/MaskTextSpotter/maskrcnn_benchmark/engine/inference.py", line 32, in output = [o.to(cpu_device) for o in output] AttributeError: 'list' object has no attribute 'to' --- tools/train_net.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tools/train_net.py b/tools/train_net.py index 233ec93..fc61969 100644 --- a/tools/train_net.py +++ b/tools/train_net.py @@ -15,7 +15,7 @@ from maskrcnn_benchmark.data import make_data_loader from maskrcnn_benchmark.solver import make_lr_scheduler from maskrcnn_benchmark.solver import make_optimizer -from maskrcnn_benchmark.engine.inference import inference +from maskrcnn_benchmark.engine.text_inference import inference from maskrcnn_benchmark.engine.trainer import do_train from maskrcnn_benchmark.modeling.detector import build_detection_model from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer @@ -94,6 +94,7 @@ def test(cfg, model, distributed): mkdir(output_folder) output_folders[idx] = output_folder data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=distributed) + model_name = cfg.MODEL.WEIGHT.split('/')[-1] for output_folder, data_loader_val in zip(output_folders, data_loaders_val): inference( model, @@ -104,6 +105,8 @@ def test(cfg, model, distributed): expected_results=cfg.TEST.EXPECTED_RESULTS, expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL, output_folder=output_folder, + model_name=model_name, + cfg=cfg, ) synchronize()