# explore_SAM.py (forked from AnyLoc/AnyLoc)
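"""
Explore SAM (Segment Anything Model) on a VPR dataset: load a dataset
(Baidu, Oxford, gardens, or a generic BaseDataset), run SAM's automatic
mask generator on the first image, and visualize the predicted masks.
"""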
import os
import sys
from pathlib import Path

# Set the './../' from the script folder
dir_name = None
try:
    dir_name = os.path.dirname(os.path.realpath(__file__))
except NameError:
    print('WARN: __file__ not found, trying local')
    dir_name = os.path.abspath('')
lib_path = os.path.realpath(f'{Path(dir_name).parent}')
# Add to path
if lib_path not in sys.path:
    print(f'Adding library path: {lib_path} to PYTHONPATH')
    sys.path.append(lib_path)
else:
    print(f'Library path {lib_path} already in PYTHONPATH')
# %%
import torch
from torch.nn import functional as F
# from dino_extractor import ViTExtractor
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from PIL import Image
import numpy as np
import tyro
from dataclasses import dataclass, field
from utilities import VLAD, get_top_k_recall, seed_everything
import einops as ein
import wandb
import matplotlib.pyplot as plt
import time
import joblib
import traceback
from tqdm.auto import tqdm
from dvgl_benchmark.datasets_ws import BaseDataset
from configs import ProgArgs, prog_args, BaseDatasetArgs, \
    base_dataset_args, device
from typing import Union, Literal, Tuple, List
from custom_datasets.baidu_dataloader import Baidu_Dataset
from custom_datasets.oxford_dataloader import Oxford
from custom_datasets.gardens import Gardens


@dataclass
class LocalArgs:
    # Program arguments (dataset directories and wandb)
    prog: ProgArgs = ProgArgs(wandb_proj="Dino-Descs",
                              wandb_group="Direct-Descs")
    # BaseDataset arguments
    bd_args: BaseDatasetArgs = base_dataset_args
    # Experiment identifier (None = don't use)
    exp_id: Union[str, None] = None
    # Dino parameters
    model_type: Literal["dino_vits8", "dino_vits16", "dino_vitb8",
                        "dino_vitb16", "vit_small_patch8_224",
                        "vit_small_patch16_224", "vit_base_patch8_224",
                        "vit_base_patch16_224"] = "dino_vits8"
    """
    Model for Dino to use as the base model.
    """
    # Number of clusters for VLAD
    num_clusters: int = 8
    # Stride for ViT (extractor)
    vit_stride: int = 4
    # Down-scaling H, W resolution for images (before giving to Dino)
    down_scale_res: Tuple[int, int] = (224, 298)
    # Layer for extracting Dino feature (descriptors)
    desc_layer: int = 11
    # Facet for extracting descriptors
    desc_facet: Literal["key", "query", "value", "token"] = "key"
    # Apply log binning to the descriptor
    desc_bin: bool = False
    # Dataset split for VPR (BaseDataset)
    data_split: Literal["train", "test", "val"] = "test"
    # Sub-sample query images (RAM or VRAM constraints) (1 = off)
    sub_sample_qu: int = 1
    # Sub-sample database images (RAM or VRAM constraints) (1 = off)
    sub_sample_db: int = 1
    # Sub-sample database images for VLAD clustering only
    sub_sample_db_vlad: int = 1
    """
    Use sub-sampling for creating the VLAD cluster centers. Use
    this to reduce the RAM usage during the clustering process.
    Unlike `sub_sample_qu` and `sub_sample_db`, this is only used
    for clustering and not for the actual VLAD computation.
    """
    # Values for top-k (for monitoring)
    top_k_vals: List[int] = field(default_factory=lambda:
                                  list(range(1, 21, 1)))
    # Show a matplotlib plot for recalls
    show_plot: bool = False
    # Use hard or soft descriptor assignment for VLAD
    vlad_assignment: Literal["hard", "soft"] = "hard"
    # Softmax temperature for VLAD (soft assignment only)
    vlad_soft_temp: float = 1.0


def show_anns(anns):
    """Overlay SAM mask annotations on the current matplotlib axes."""
    if len(anns) == 0:
        return
    # Draw larger masks first so smaller ones remain visible on top
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    for ann in sorted_anns:
        m = ann['segmentation']  # boolean (H, W) mask
        img = np.ones((m.shape[0], m.shape[1], 3))
        color_mask = np.random.random((1, 3)).tolist()[0]
        for i in range(3):
            img[:, :, i] = color_mask[i]
        # Use the mask as the alpha channel of a translucent colored overlay
        ax.imshow(np.dstack((img, m * 0.35)))
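

# tyro builds a CLI from LocalArgs; an example invocation (paths and
# dataset names below are illustrative, not prescribed by this repo):
#   python explore_SAM.py --prog.data-vg-dir /path/to/vg_datasets \
#       --prog.vg-dataset-name baidu_datasets --data-split test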
largs = tyro.cli(LocalArgs, description=__doc__)
ds_dir = largs.prog.data_vg_dir
ds_name = largs.prog.vg_dataset_name
# Load the requested VPR dataset
if ds_name == "baidu_datasets":
    vpr_ds = Baidu_Dataset(largs.bd_args, ds_dir, ds_name,
                           largs.data_split)
elif ds_name == "Oxford":
    vpr_ds = Oxford(ds_dir)
elif ds_name == "gardens":
    vpr_ds = Gardens(largs.bd_args, ds_dir, ds_name, largs.data_split)
else:
    vpr_ds = BaseDataset(largs.bd_args, ds_dir, ds_name,
                         largs.data_split)

# Load SAM (ViT-L backbone) from a local checkpoint (cluster-specific path)
# and build the automatic mask generator
sam = sam_model_registry["vit_l"](checkpoint="/ocean/projects/cis220039p/jkarhade/data/sam_model/sam_vit_l_0b3195.pth")
sam.to(device)
mask_generator = SamAutomaticMaskGenerator(sam)

# Take the first image from the dataset and down-scale it
img = vpr_ds[0][0]
img = ein.rearrange(img, "c h w -> 1 c h w").to(device)
img = F.interpolate(img, largs.down_scale_res)
img = img.squeeze(0)
# SamAutomaticMaskGenerator.generate expects an (H, W, 3) uint8 numpy array;
# this conversion assumes the dataset yields float tensors in [0, 1]
img_np = (img.permute(1, 2, 0).detach().cpu().numpy() * 255).astype(np.uint8)
# Each returned mask is a dict with keys such as 'segmentation' (a boolean
# (H, W) array), 'area', 'bbox', and 'predicted_iou'
masks = mask_generator.generate(img_np)
print(len(masks))
print(masks[0].keys())
# Show the image with the predicted masks overlaid
plt.figure(figsize=(20, 20))
plt.imshow(img_np)
show_anns(masks)
plt.axis('off')
plt.show()
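# On a headless machine (e.g. a cluster node), saving the overlay may be
# more practical than plt.show(); the filename below is illustrative:
# plt.savefig("sam_masks_overlay.png", bbox_inches="tight")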