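"""detection_torch.py

Runs 3D object detection with Open3D-ML's PointPillars model (PyTorch backend):
first on a few point clouds from the KITTI test split, then on a folder of
custom .pcd files, and finally shows all detections in the Open3D visualizer.
"""
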
import os
import open3d as o3d
import open3d.ml as _ml3d
import open3d.ml.torch as ml3d
from open3d.ml.vis import Visualizer, BoundingBox3D, LabelLUT
from tqdm import tqdm
import time
import numpy as np
import glob


def prepare_point_cloud_for_inference(pcd):
    # Remove NaN and infinite values
    pcd.remove_non_finite_points()
    # Extract the xyz points
    xyz = pcd.points
    # The PointPillars model expects a 4th channel (intensity), which this custom
    # data does not have, so we add it here with a default value of 0.5
    xyzi = []
    for point in xyz:
        xyzi.append(list(point) + [0.5])
    xyzi = np.array(xyzi)
    # Pack the points into the format expected by the inference pipeline
    data = {"point": xyzi, "feat": None, "label": np.zeros((len(xyz),), dtype=np.int32)}
    return data, pcd


def load_custom_dataset(dataset_path, candidates_number=1000, step=1):
    print("Loading custom dataset")
    pcd_paths = glob.glob(dataset_path + "/*.pcd")
    pcds = []
    for count, pcd_path in enumerate(pcd_paths):
        if count % step == 0:
            pcds.append(o3d.io.read_point_cloud(pcd_path))
        if count == candidates_number:
            break
    return pcds


def filter_detections(detections, min_conf=0.5):
    good_detections = []
    for detection in detections:
        if detection.confidence >= min_conf:
            good_detections.append(detection)
    return good_detections


# Load an ML configuration file
cfg_file = "/home/carlos/Open3D/build/Open3D-ML/ml3d/configs/pointpillars_kitti.yml"
cfg = _ml3d.utils.Config.load_from_file(cfg_file)

# Load the PointPillars model
model = ml3d.models.PointPillars(**cfg.model)

# Add the paths to the KITTI dataset and to your own custom dataset
cfg.dataset['dataset_path'] = '/media/carlos/SeagateExpansionDrive/kitti/Kitti'
cfg.dataset['custom_dataset_path'] = './pcds'

# Load the datasets
dataset = ml3d.datasets.KITTI(cfg.dataset.pop('dataset_path', None), **cfg.dataset)
custom_dataset = load_custom_dataset(cfg.dataset.pop('custom_dataset_path', None))

# Create the ML pipeline
pipeline = ml3d.pipelines.ObjectDetection(model, dataset=dataset, device="gpu", **cfg.pipeline)

# Download the pretrained weights if they are not already present
ckpt_folder = "./logs/"
os.makedirs(ckpt_folder, exist_ok=True)
ckpt_path = ckpt_folder + "pointpillars_kitti_202012221652utc.pth"
pointpillar_url = "https://storage.googleapis.com/open3d-releases/model-zoo/pointpillars_kitti_202012221652utc.pth"
if not os.path.exists(ckpt_path):
    cmd = "wget {} -O {}".format(pointpillar_url, ckpt_path)
    os.system(cmd)

# Load the parameters of the model
pipeline.load_ckpt(ckpt_path=ckpt_path)

# Select the test split of the KITTI dataset
test_split = dataset.get_split("test")

# Prepare the visualizer
vis = Visualizer()

# List to accumulate the predictions
data_list = []

# Let's detect objects in the first few point clouds of the KITTI test set
for idx in tqdm(range(10)):
    # Get one test point cloud from the KITTI dataset
    data = test_split.get_data(idx)
    # Run the inference
    result = pipeline.run_inference(data)[0]
    # Filter out detections with low confidence
    result = filter_detections(result)
    # Prepare a dictionary usable by the visualization tool
    pred = {
        "name": "KITTI_" + str(idx),
        "points": data["point"],
        "bounding_boxes": result
    }
    # Append the prediction to the list
    data_list.append(pred)

# Let's detect objects in the point clouds of the custom set
for idx in tqdm(range(len(custom_dataset))):
    # Get one point cloud and format it for inference
    data, pcd = prepare_point_cloud_for_inference(custom_dataset[idx])
    # Run the inference
    result = pipeline.run_inference(data)[0]
    # Filter out detections with low confidence
    result = filter_detections(result, min_conf=0.3)
    # Prepare a dictionary usable by the visualization tool
    pred = {
        "name": "Custom_" + str(idx),
        "points": data["point"],
        "bounding_boxes": result
    }
    # Append the prediction to the list
    data_list.append(pred)

# Visualize the results
vis.visualize(data_list, None, bounding_boxes=None)
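
# Usage sketch (assumes cfg_file, the KITTI dataset_path, and the ./pcds folder
# above point to valid data on your machine): run `python detection_torch.py`.
# An Open3D visualizer window should open showing each point cloud together with
# its predicted bounding boxes.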