-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathdemo.py
52 lines (40 loc) · 1.5 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from lapsrn import *
from PIL import Image, ImageFilter
import torchvision.transforms.functional as tf
from torchvision import transforms
def load_ckp(checkpoint_fpath, model, optimizer, map_location=None):
    """
    Restore a training checkpoint into an existing model and optimizer.

    checkpoint_fpath: path of the checkpoint file to load
    model: model that we want to load checkpoint parameters into
    optimizer: optimizer we defined in previous training
    map_location: optional device remapping forwarded to torch.load
                  (e.g. 'cpu' to open a GPU-trained checkpoint without CUDA)

    Returns (model, optimizer, epoch, valid_loss_min), where epoch is the
    value stored under 'epoch' and valid_loss_min is a plain Python number.
    Raises KeyError when the checkpoint lacks one of the keys
    'state_dict', 'optimizer', 'epoch' or 'valid_loss_min'.
    """
    # load check point
    checkpoint = torch.load(checkpoint_fpath, map_location=map_location)
    # initialize state_dict from checkpoint to model
    model.load_state_dict(checkpoint['state_dict'])
    # initialize optimizer from checkpoint to optimizer
    optimizer.load_state_dict(checkpoint['optimizer'])
    # The minimum validation loss may have been saved either as a 0-dim
    # tensor / numpy scalar or as a plain float; only the former has .item().
    valid_loss_min = checkpoint['valid_loss_min']
    if hasattr(valid_loss_min, 'item'):
        valid_loss_min = valid_loss_min.item()
    # return model, optimizer, epoch value, min validation loss
    return model, optimizer, checkpoint['epoch'], valid_loss_min
def get_y(img):
    """
    Return channel 0 of img after converting it to 'YCbCr' mode.

    img: image object providing convert() and getchannel()
         (the script passes the result of PIL's Image.open)
    """
    # Channel 0 of a YCbCr image is the Y (luma) plane.
    return img.convert('YCbCr').getchannel(0)
# --- Demo: run a trained LapSrnMS on the luma channel of one image ---------

# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the demo does not crash on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Restore the weights stored under the 'state_dict' key of best.pt.
checkpoint = torch.load('best.pt', map_location=device)
# NOTE(review): the meaning of (5, 5, 4) is defined in lapsrn — confirm there.
net = LapSrnMS(5, 5, 4)
net.load_state_dict(checkpoint['state_dict'])
net.to(device)
# Inference only: switch off training-time behaviour of any dropout /
# normalisation layers the network may contain.
net.eval()

# Read the input, keep only its Y channel, and close the file handle promptly.
with Image.open("pr-curve.jpeg") as source_image:
    im_4x = get_y(source_image)

# Image -> tensor, add a leading batch dimension, move to the model's device.
im = tf.to_tensor(im_4x)
im = im.unsqueeze(0)
im = im.to(device)

with torch.no_grad():
    # The network returns two outputs, unpacked here as the 2x and 4x results.
    out_2x, out_4x = net(im)

# ToPILImage scales float tensors by 255 and casts to 8-bit, so values outside
# [0, 1] would be corrupted; clamp both ends, not just the upper one.
out_2x.clamp_(0, 1)
out_4x.clamp_(0, 1)

# Drop the batch dimension, bring the data back to the CPU and write PNGs.
out_2x = transforms.ToPILImage()(out_2x[0].cpu())
out_4x = transforms.ToPILImage()(out_4x[0].cpu())
out_2x.save("out_2x.png", "PNG")
out_4x.save("out_4x.png", "PNG")