-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmain_fully_supervised.py
83 lines (61 loc) · 2.03 KB
/
main_fully_supervised.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import hydra
from omegaconf import DictConfig
import ignite.distributed as idist
from ignite.utils import manual_seed, setup_logger
import utils
import trainers
def training(local_rank, cfg):
    """Run fully-supervised training in the current (possibly distributed) process.

    This function is handed to ``idist.Parallel.run`` by ``main`` and is
    therefore executed once per spawned process.

    Args:
        local_rank: Rank value passed in by the launcher as the first
            positional argument; the process with ``local_rank == 0`` logs
            the resolved configuration.
        cfg: Hydra/OmegaConf configuration. The fields read here are
            ``seed``, ``debug``, ``solver.epoch_length`` and
            ``solver.num_epochs``; the whole object is also forwarded to
            ``utils.initialize``, ``utils.get_dataflow`` and
            ``trainers.create_trainer``.

    Returns:
        None. Any exception raised by ``trainer.run`` is caught, its
        traceback is printed to stdout, and the function returns normally.
    """
    # Local imports keep this block self-contained (the file already imports
    # ``traceback`` locally for the same reason).
    import traceback

    from omegaconf import OmegaConf

    rank = idist.get_rank()
    logger = setup_logger("Fully-Supervised Training", distributed_rank=rank)

    if local_rank == 0:
        # ``cfg.pretty()`` is deprecated in OmegaConf 2.0 and removed in 2.1;
        # ``OmegaConf.to_yaml`` is the supported way to dump the config.
        logger.info(OmegaConf.to_yaml(cfg))

    # Offset the seed by the global rank so each process gets its own
    # random stream.
    manual_seed(cfg.seed + rank)
    device = idist.device()

    model, ema_model, optimizer, sup_criterion, lr_scheduler = utils.initialize(cfg)
    supervised_train_loader, test_loader, *_ = utils.get_dataflow(cfg)

    def train_step(engine, batch):
        """Do one optimisation step on the supervised part of ``batch``.

        Returns a dict with the scalar loss under the key ``"sup_loss"``.
        """
        model.train()
        optimizer.zero_grad()

        sup_batch = batch["sup_batch"]
        # ``Tensor.to`` hands back the same tensor when it already lives on
        # ``device``, so moving both unconditionally is safe and also covers
        # the case where only the targets are on another device.
        x = sup_batch["image"].to(device, non_blocking=True)
        y = sup_batch["target"].to(device, non_blocking=True)

        y_pred = model(x)
        sup_loss = sup_criterion(y_pred, y)
        sup_loss.backward()
        optimizer.step()

        return {"sup_loss": sup_loss.item()}

    trainer = trainers.create_trainer(
        train_step,
        output_names=["sup_loss"],
        model=model,
        ema_model=ema_model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        supervised_train_loader=supervised_train_loader,
        test_loader=test_loader,
        cfg=cfg,
        logger=logger,
    )

    epoch_length = cfg.solver.epoch_length
    # Debug runs are capped at two epochs.
    num_epochs = 2 if cfg.debug else cfg.solver.num_epochs

    try:
        trainer.run(
            supervised_train_loader,
            epoch_length=epoch_length,
            max_epochs=num_epochs,
        )
    except Exception:
        # Deliberate best-effort: report the failure from every process via
        # stdout (the logger is set up per distributed rank) and let the
        # function return instead of propagating.
        print(traceback.format_exc())
@hydra.main(config_path="config", config_name="fully_supervised")
def main(cfg: DictConfig) -> None:
    """Launch ``training`` through ignite's distributed ``Parallel`` helper.

    The backend and the number of processes per node are taken from
    ``cfg.distributed``; ``cfg`` itself is forwarded to ``training``.
    """
    dist_cfg = cfg.distributed
    launcher = idist.Parallel(
        backend=dist_cfg.backend,
        nproc_per_node=dist_cfg.nproc_per_node,
    )
    with launcher as parallel:
        parallel.run(training, cfg)
# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()