Skip to content

Commit

Permalink
Update util.py
Browse files Browse the repository at this point in the history
add py_nms, py_box_overlap and average_precision
  • Loading branch information
tanghaotommy authored Jun 24, 2020
1 parent 5fbd776 commit 364621c
Showing 1 changed file with 118 additions and 0 deletions.
118 changes: 118 additions & 0 deletions utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,124 @@ def npy2submission(set_name, save_path, bbox_dir, prep_dir, postfix='detection')
submission.to_csv(save_path, sep=',', index=False, header=True)


def average_precision(labels, y_pred):
# Compute number of objects
true_objects = len(np.unique(labels))
pred_objects = len(np.unique(y_pred))
print("Number of true objects:", true_objects)
print("Number of predicted objects:", pred_objects)

# Compute intersection between all objects
intersection = np.histogram2d(labels.flatten(), y_pred.flatten(), bins=(true_objects, pred_objects))[0]

# Compute areas (needed for finding the union between all objects)
area_true = np.histogram(labels, bins = true_objects)[0]
area_pred = np.histogram(y_pred, bins = pred_objects)[0]
area_true = np.expand_dims(area_true, -1)
area_pred = np.expand_dims(area_pred, 0)

# Compute union
union = area_true + area_pred - intersection

# Exclude background from the analysis
intersection = intersection[1:,1:]
union = union[1:,1:]
union[union == 0] = 1e-9

# Compute the intersection over union
iou = intersection / union
dice = 2 * intersection / (union + intersection)

# Precision helper function
def precision_at(threshold, iou):
matches = iou > threshold
true_positives = np.sum(matches, axis=1) == 1 # Correct objects
false_positives = np.sum(matches, axis=0) == 0 # Missed objects
false_negatives = np.sum(matches, axis=1) == 0 # Extra objects
tp, fp, fn = np.sum(true_positives), np.sum(false_positives), np.sum(false_negatives)
return tp, fp, fn

# Loop over IoU thresholds
prec = []
# print("Thresh\tTP\tFP\tFN\tPrec.")
for t in np.arange(0.5, 1.0, 0.05):
tp, fp, fn = precision_at(t, iou)
p = tp / float(tp + fp + fn)
# print("{:1.3f}\t{}\t{}\t{}\t{:1.3f}".format(t, tp, fp, fn, p))
prec.append(p)
# print("AP\t-\t-\t-\t{:1.3f}".format(np.mean(prec)))
return prec, np.max(dice, axis=1)


def py_nms(dets, thresh):
# Check the input dtype
if isinstance(dets, torch.Tensor):
if dets.is_cuda:
dets = dets.cpu()
dets = dets.data.numpy()

z = dets[:, 1]
y = dets[:, 2]
x = dets[:, 3]
d = dets[:, 4]
h = dets[:, 5]
w = dets[:, 6]
scores = dets[:, 0]

areas = d * h * w
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)

xx0 = np.maximum(x[i] - w[i] / 2., x[order[1:]] - w[order[1:]] / 2.)
yy0 = np.maximum(y[i] - h[i] / 2., y[order[1:]] - h[order[1:]] / 2.)
zz0 = np.maximum(z[i] - d[i] / 2., z[order[1:]] - d[order[1:]] / 2.)
xx1 = np.minimum(x[i] + w[i] / 2., x[order[1:]] + w[order[1:]] / 2.)
yy1 = np.minimum(y[i] + h[i] / 2., y[order[1:]] + h[order[1:]] / 2.)
zz1 = np.minimum(z[i] + d[i] / 2., z[order[1:]] + d[order[1:]] / 2.)

inter_w = np.maximum(0.0, xx1 - xx0)
inter_h = np.maximum(0.0, yy1 - yy0)
inter_d = np.maximum(0.0, zz1 - zz0)
intersect = inter_w * inter_h * inter_d
overlap = intersect / (areas[i] + areas[order[1:]] - intersect)

inds = np.where(overlap <= thresh)[0]
order = order[inds + 1]

return torch.from_numpy(dets[keep]), torch.LongTensor(keep)


def py_box_overlap(boxes1, boxes2):
overlap = np.zeros((len(boxes1), len(boxes2)))

z1, y1, x1 = boxes1[:, 0], boxes1[:, 1], boxes1[:, 2]
d1, h1, w1 = boxes1[:, 3], boxes1[:, 4], boxes1[:, 5]
areas1 = d1 * h1 * w1

z2, y2, x2 = boxes2[:, 0], boxes2[:, 1], boxes2[:, 2]
d2, h2, w2 = boxes2[:, 3], boxes2[:, 4], boxes2[:, 5]
areas2 = d2 * h2 * w2

for i in range(len(boxes1)):
xx0 = np.maximum(x1[i] - w1[i] / 2., x2 - w2 / 2.)
yy0 = np.maximum(y1[i] - h1[i] / 2., y2 - h2 / 2.)
zz0 = np.maximum(z1[i] - d1[i] / 2., z2 - d2 / 2.)
xx1 = np.minimum(x1[i] + w1[i] / 2., x2 + w2 / 2.)
yy1 = np.minimum(y1[i] + h1[i] / 2., y2 + h2 / 2.)
zz1 = np.minimum(z1[i] + d1[i] / 2., z2 + d2 / 2.)

inter_w = np.maximum(0.0, xx1 - xx0)
inter_h = np.maximum(0.0, yy1 - yy0)
inter_d = np.maximum(0.0, zz1 - zz0)
intersect = inter_w * inter_h * inter_d
overlap[i] = intersect / (areas1[i] + areas2 - intersect)

return overlap


def center_box_to_coord_box(bboxes):
"""
Convert bounding box using center of rectangle and side lengths representation to
Expand Down

0 comments on commit 364621c

Please sign in to comment.