-
-
Notifications
You must be signed in to change notification settings - Fork 343
/
Copy pathrun_profile.py
93 lines (71 loc) · 2.46 KB
/
run_profile.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
import torch
import torch.nn as nn
from torch import Tensor
import re
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
from src.core import YAMLConfig, yaml_utils
from src.solver import TASKS
from typing import Dict, List, Optional, Any
__all__ = ["profile_stats"]
def profile_stats(
    model: nn.Module,
    data: Optional[Tensor]=None,
    shape: List[int]=[1, 3, 640, 640],
    verbose: bool=False
) -> Dict[str, Any]:
    """Count trainable parameters and profile the FLOPs of a forward pass.

    The model is switched to eval mode for the profiled run and its previous
    training/eval mode is restored afterwards, even if the forward pass raises.

    Args:
        model: Module to profile.
        data: Input passed as ``model(data)``. When ``None``, a random tensor
            of ``shape`` is created with the dtype and device of the model's
            first parameter (default dtype on CPU if it has no parameters).
        shape: Shape of the random input used when ``data`` is ``None``.
            The default list is only read, never mutated.
        verbose: When true, print the profiler table and a short summary.

    Returns:
        A dict with keys ``'n_parameters'`` (number of trainable parameter
        elements), ``'n_flops'`` (profiled FLOPs of one active step, in
        millions) and ``'info'`` (the profiler table as a string).
    """
    was_training = model.training

    # requires_grad does not depend on train/eval mode, so no mode switch is
    # needed to count the trainable parameters.
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    if data is None:
        first_param = next(model.parameters(), None)
        if first_param is None:
            # Parameter-less model: fall back to the global defaults.
            dtype, device = torch.get_default_dtype(), torch.device('cpu')
        else:
            dtype, device = first_param.dtype, first_param.device
        data = torch.rand(*shape, dtype=dtype, device=device)
        print(device)
    elif isinstance(data, Tensor):
        # Report the shape that is actually profiled, not the unused default.
        shape = list(data.shape)

    # Profiler schedule: a single warm-up step followed by one recorded step.
    wait = 0
    warmup = 1
    active = 1
    repeat = 1
    skip_first = 0
    n_step = skip_first + (wait + warmup + active) * repeat

    # Only request CUDA tracing when a CUDA runtime is actually available.
    activities = [torch.profiler.ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(torch.profiler.ProfilerActivity.CUDA)

    model.eval()
    try:
        with torch.profiler.profile(
            activities=activities,
            schedule=torch.profiler.schedule(
                wait=wait,
                warmup=warmup,
                active=active,
                repeat=repeat,
                skip_first=skip_first,
            ),
            with_flops=True,
        ) as p:
            for _ in range(n_step):
                _ = model(data)
                p.step()
    finally:
        # Put the model back into the mode the caller handed it over in.
        model.train(was_training)

    averages = p.key_averages()
    info = averages.table(sort_by='self_cuda_time_total', row_limit=-1)

    # Sum the raw per-operator FLOP counts instead of scraping the rendered
    # table: the table's FLOPs column is auto-scaled (K/M/G...) and omitted
    # when nothing was recorded, so text parsing can silently mis-scale.
    total_flops = sum((getattr(evt, 'flops', 0) or 0) for evt in averages)
    num_flops = total_flops / 1e6 / active

    if verbose:
        print(info)
        print(f'Total number of trainable parameters: {num_params}')
        print(f'Total number of flops: {int(num_flops)}M with {shape}')

    return {'n_parameters': num_params, 'n_flops': num_flops, 'info': info}
if __name__ == "__main__":
    # Script entry point: build the model described by a YAML config, move it
    # to the requested device and print its profiling statistics.
    from argparse import ArgumentParser

    cli = ArgumentParser()
    cli.add_argument('-c', '--config', type=str, required=True)
    cli.add_argument('-d', '--device', type=str, default='cuda:0', help='device')
    options = cli.parse_args()

    config = YAMLConfig(options.config, device=options.device)
    net = config.model.to(options.device)
    profile_stats(net, verbose=True)