-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplotting_helpers.py
127 lines (105 loc) · 3.72 KB
/
plotting_helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim
from torch.autograd import Variable
from torchvision import models
def test_network(net, trainloader):
    """Run a single optimisation step on ``net`` as a smoke test.

    The first batch is pulled from ``trainloader``; its images are fed
    through the network and the output is scored against those same images
    with an MSE loss (the batch labels are ignored).  One Adam update with a
    learning rate of 0.001 is then applied to the network's parameters.

    Args:
        net: module exposing ``parameters()`` and callable on a batch of
            images.
        trainloader: iterable whose items unpack as ``(images, labels)``.

    Returns:
        ``True`` once the forward pass, backward pass and optimiser step
        have all completed without raising.
    """
    criterion = nn.MSELoss()
    optimizer = optim.Adam(net.parameters(), lr=0.001)

    # Python 3 iterators expose __next__, not .next(); the built-in next()
    # works with DataLoader iterators as well as any plain iterable.
    images, _labels = next(iter(trainloader))

    # The images serve as both input and target; plain tensors are used
    # directly instead of the deprecated Variable wrapper.
    inputs = images
    targets = images

    # Clear gradients left over from any previous step.
    optimizer.zero_grad()

    # Forward pass, then backward pass, then update weights.  Calling the
    # module (rather than .forward) also runs any registered hooks.
    output = net(inputs)
    loss = criterion(output, targets)
    loss.backward()
    optimizer.step()

    return True


# The name matches pytest's default ``test_*`` pattern; mark the helper so
# pytest does not try to collect it when this module is imported into a
# test file.
test_network.__test__ = False
def imshow(image, ax=None, title=None, normalize=True):
    """Draw an image tensor on a matplotlib axis and return that axis.

    Args:
        image: tensor converted with ``.numpy()``; its first axis is moved
            to the end before plotting.
        ax: axis to draw on; a new figure and axis are created when ``None``.
        title: accepted but not used by this function.
        normalize: when true, the pixels are rescaled as ``std * x + mean``
            with fixed per-channel constants and clipped to ``[0, 1]``.

    Returns:
        The axis the image was drawn on, with spines hidden, tick marks
        shortened to zero length and tick labels blanked.
    """
    if ax is None:
        _, ax = plt.subplots()

    # Reorder axes (0, 1, 2) -> (1, 2, 0) so the leading axis comes last.
    pixels = image.numpy().transpose((1, 2, 0))

    if normalize:
        # Presumably the ImageNet normalisation statistics -- confirm
        # against the transform used when loading the data.
        channel_mean = np.array([0.485, 0.456, 0.406])
        channel_std = np.array([0.229, 0.224, 0.225])
        pixels = np.clip(channel_std * pixels + channel_mean, 0, 1)

    ax.imshow(pixels)

    # Strip the frame and tick decorations so only the picture remains.
    for side in ('top', 'right', 'left', 'bottom'):
        ax.spines[side].set_visible(False)
    ax.tick_params(axis='both', length=0)
    ax.set_xticklabels('')
    ax.set_yticklabels('')

    return ax
def view_recon(img, recon):
    """Display an image and its reconstruction side by side.

    Both arguments are converted to NumPy arrays, squeezed, and drawn on
    two axes that share their x and y limits; the axis decorations are
    switched off.

    Args:
        img: PyTorch tensor holding the original image.
        recon: PyTorch tensor holding the reconstruction; it is detached
            from the autograd graph before conversion.
    """
    _, axes = plt.subplots(ncols=2, sharex=True, sharey=True)
    axes[0].imshow(img.numpy().squeeze())
    axes[1].imshow(recon.detach().numpy().squeeze())
    for ax in axes:
        ax.axis('off')
        # 'box-forced' is no longer an accepted value in current matplotlib;
        # 'box' is the supported spelling and works with shared axes.
        ax.set_adjustable('box')
def view_classify(img, ps, version="MNIST"):
    """Show an image next to a bar chart of its predicted class probabilities.

    Args:
        img: PyTorch tensor that can be reshaped to ``(1, 28, 28)``.
        ps: PyTorch tensor of class probabilities; after squeezing it is
            plotted against ten bar positions.
        version: ``"MNIST"`` labels the bars 0-9, ``"Fashion"`` labels them
            with clothing category names; any other value leaves the
            default tick labels in place.
    """
    probabilities = ps.detach().numpy().squeeze()

    _, (ax1, ax2) = plt.subplots(figsize=(6, 9), ncols=2)

    # reshape() leaves the caller's tensor untouched, unlike the in-place
    # resize_() which would permanently change the shape of ``img``.
    ax1.imshow(img.detach().reshape(1, 28, 28).numpy().squeeze())
    ax1.axis('off')

    ax2.barh(np.arange(10), probabilities)
    ax2.set_aspect(0.1)
    ax2.set_yticks(np.arange(10))
    if version == "MNIST":
        ax2.set_yticklabels(np.arange(10))
    elif version == "Fashion":
        ax2.set_yticklabels(['T-shirt/top',
                             'Trouser',
                             'Pullover',
                             'Dress',
                             'Coat',
                             'Sandal',
                             'Shirt',
                             'Sneaker',
                             'Bag',
                             'Ankle Boot'], size='small')
    ax2.set_title('Class Probability')
    # Leave a little headroom beyond probability 1.0.
    ax2.set_xlim(0, 1.1)

    plt.tight_layout()
def save_model(model, path, verbose=True):
    """Write a checkpoint for ``model`` to ``path`` using ``torch.save``.

    The checkpoint is a dict with two entries:

    * ``'parameters'`` -- the ``model.parameters`` attribute itself (it is
      not called here).
    * ``'state_dict'`` -- the result of ``model.state_dict()``.

    NOTE(review): because ``'parameters'`` holds a bound method, pickling it
    presumably serialises the whole module alongside the state dict --
    confirm whether that is intended.

    Args:
        model: object providing ``parameters`` and ``state_dict()``.
        path: destination handed straight to ``torch.save``.
        verbose: when true, print a progress message before and after the
            checkpoint is written.
    """
    def report(message):
        # Progress output is emitted only on request.
        if verbose:
            print(message)

    checkpoint = {}
    checkpoint['parameters'] = model.parameters
    checkpoint['state_dict'] = model.state_dict()
    report("Checkpoint created.")

    torch.save(checkpoint, path)
    report("Checkpoint saved.")
def load_model(path):
    """Rebuild the DenseNet-121 classifier and load weights from a checkpoint.

    A fresh ``densenet121`` is created, its ``classifier`` head is replaced
    by a 1024 -> 512 -> 256 -> 2 fully connected stack with ReLU, dropout and
    a final ``LogSoftmax``, and the checkpoint's ``'state_dict'`` entry is
    loaded into it.

    Args:
        path: checkpoint location handed to ``torch.load``; tensors are
            mapped to the CPU.

    Returns:
        The model with the checkpoint weights loaded, or ``None`` when the
        checkpoint could not be read (the error is printed).

    Raises:
        KeyError: if the checkpoint has no ``'state_dict'`` entry.
    """
    try:
        # NOTE(security): torch.load unpickles the file, so only open
        # checkpoints from a trusted source.  NOTE(review): recent PyTorch
        # releases may refuse checkpoints that contain pickled objects
        # other than tensors unless told otherwise -- confirm against the
        # installed version.
        checkpoint = torch.load(path, map_location='cpu')
    except Exception as err:
        print(err)
        return None

    # No pretrained weights are requested: the checkpoint supplies them.
    # Relying on the default avoids the deprecated ``pretrained`` keyword.
    model = models.densenet121()
    model.classifier = nn.Sequential(nn.Linear(1024, 512),
                                     nn.ReLU(),
                                     nn.Dropout(0.2),
                                     nn.Linear(512, 256),
                                     nn.ReLU(),
                                     nn.Dropout(0.1),
                                     nn.Linear(256, 2),
                                     nn.LogSoftmax(dim=1))

    # checkpoint['parameters'] is deliberately NOT assigned onto the model:
    # doing so would shadow nn.Module.parameters with a method bound to the
    # pickled original, so optimisers built from model.parameters() would
    # update the wrong tensors.  The state dict carries all the weights.
    model.load_state_dict(checkpoint['state_dict'])
    return model