Skip to content

Commit

Permalink
Fix inference issue with large image (#368)
Browse files Browse the repository at this point in the history
  • Loading branch information
sltlls authored Jun 22, 2022
1 parent 29f7991 commit 01439d1
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions mmrotate/apis/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,20 @@ def inference_detector_by_patches(model,
sizes, steps = get_multiscale_patch(sizes, steps, ratios)
windows = slide_window(width, height, sizes, steps)

# prepare patch data
patch_datas = []
for window in windows:
data = dict(img=img, win=window.tolist())
# build the data pipeline
data = test_pipeline(data)
patch_datas.append(data)

results = []
start = 0
while True:
data = patch_datas[start:start + bs]
data = collate(data, samples_per_gpu=len(data))
# prepare patch data
patch_datas = []
if (start + bs) > len(windows):
end = len(windows)
else:
end = start + bs
for window in windows[start:end]:
data = dict(img=img, win=window.tolist())
data = test_pipeline(data)
patch_datas.append(data)
data = collate(patch_datas, samples_per_gpu=len(patch_datas))
# just get the actual data from DataContainer
data['img_metas'] = [
img_metas.data[0] for img_metas in data['img_metas']
Expand All @@ -80,7 +81,7 @@ def inference_detector_by_patches(model,
with torch.no_grad():
results.extend(model(return_loss=False, rescale=True, **data))

if start + bs >= len(patch_datas):
if end >= len(windows):
break
start += bs

Expand Down

0 comments on commit 01439d1

Please sign in to comment.