-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathcaffe_trainloss_net_shell.m
140 lines (118 loc) · 4.82 KB
/
caffe_trainloss_net_shell.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
function [ best_loss, best_iter, stat ] = caffe_trainloss_net_shell( solver_filename, best_snapshot_prefix, varargin )
%% Description:
% Trains a network by running the `caffe train` command-line tool through
% the shell, then keeps the snapshot with the lowest test loss (the shell is
% used since matcaffe was reporting wrong accuracy).
% WARNING: for now, make sure that testing and saving of snapshots happens
% with the same interval, otherwise the best test iteration has no snapshot.
%   [ best_loss, best_iter, stat ] = caffe_trainloss_net_shell( solver_filename, best_snapshot_prefix, [log_filename], [clear_other_snapshots = 1] )
% --- INPUT:
%   solver_filename       = solver prototxt to use
%   best_snapshot_prefix  = all temporary snapshots are saved with the
%       prefix given in the solver prototxt file; at the end the one with
%       the best loss is moved (renamed) to
%       <best_snapshot_prefix>__iter_<iter>__loss_<loss>.caffemodel/.solverstate
% --- OPTIONAL
%   log_filename          = file that receives stderr of `caffe train`
%       (training losses/accuracies). Default (also used when [] is passed):
%       'training_log.txt' in the directory of best_snapshot_prefix
%   clear_other_snapshots = if true, deletes non best snapshots at the end
%       (warning: be careful. Make sure that the folder where you are saving
%       temporary snapshots does not contain important snapshots)
% --- OUTPUT
%   best_loss = best (lowest) test loss
%   best_iter = iteration at which the best test loss was reported
%   stat      = losses reported during training (from parsing the log file):
%       .train              parsed training entries (from caffe_log_proc2)
%       .best_snapshot_name final name of the best .caffemodel
%       .iterations         test iterations
%       .loss_test / .loss  test loss per test iteration
%       .iterations_train   training iterations
%       .loss_train         training loss per training iteration
%       .loss_train_sync    training loss at each test iteration (NaN when
%                           the log has no training entry for it)

%% Arguments
snap_p = fileparts(best_snapshot_prefix);
if ~isempty(varargin) && ~isempty(varargin{1})
    log_filename = varargin{1};
else
    % fullfile keeps the log in the current folder when the prefix has no
    % directory part (plain concatenation would point at the root folder).
    log_filename = fullfile(snap_p, 'training_log.txt');
    ensure_dir_exists(snap_p);
end
fprintf('%s : log filename = %s \n', mfilename, log_filename);

clear_other_snapshots = 1;
if length(varargin) >= 2
    clear_other_snapshots = varargin{2};
end

%% Execution
solver_params = caffe_read_solverprototxt(solver_filename);

% --- Run training
% NOTE: the paths are inserted into the command line unquoted, so they must
% not contain spaces or shell metacharacters.
train_cmd = ['caffe train --solver=' solver_filename ' 2>' log_filename];
status = system(train_cmd);
if status ~= 0
    warning('caffe_trainloss_net_shell:caffeFailed', ...
            '%s : "caffe train" exited with status %d, see %s', ...
            mfilename, status, log_filename);
end

% --- Analyze log file
[stat.train, test_stat] = caffe_log_proc2(log_filename);
if isempty(test_stat)
    error('caffe_trainloss_net_shell:noTestResults', ...
          '%s : no test results were found in %s, cannot pick the best snapshot', ...
          mfilename, log_filename);
end

% --- Pick best iteration
% The (best_iter == 0) clause makes the next entry always replace a test
% pass reported at iteration 0, so iteration 0 is only selected when it is
% the sole test entry.
best_iter = 0;
best_loss = realmax('single');
for i = 1:length(test_stat)
    if (test_stat{i}.loss < best_loss) || (best_iter == 0)
        best_loss = test_stat{i}.loss;
        best_iter = test_stat{i}.iter;
    end
end

model_ext  = '.caffemodel';
solver_ext = '.solverstate';
best_snapshot_temp_name = sprintf('%s_iter_%d', ...
                                  solver_params.snapshot_prefix, ...
                                  best_iter );
best_snapshot_name = sprintf('%s__iter_%06d__loss_%5.10f', ...
                             best_snapshot_prefix, ...
                             best_iter, ...
                             best_loss );
stat.best_snapshot_name = [best_snapshot_name model_ext];

% The formatted loss (and sometimes the prefix) contains '.', which fileparts
% takes for an extension, so the name and "extension" are glued back together.
[snap_p, best_snapshot_name_only, best_snapshot_ext_only] = fileparts(best_snapshot_name);
best_snapshot_name_only = [best_snapshot_name_only best_snapshot_ext_only];

% --- Creating folders if they don't exist
ensure_dir_exists(snap_p);

% --- Rename files with the best iteration
temp_model  = [best_snapshot_temp_name model_ext];
temp_solver = [best_snapshot_temp_name solver_ext];
% Check both files first so a missing one cannot leave a half-renamed pair.
if ~exist(temp_model, 'file') || ~exist(temp_solver, 'file')
    error('caffe_trainloss_net_shell:snapshotMissing', ...
          '%s : snapshot %s / %s not found (are the test and snapshot intervals equal?)', ...
          mfilename, temp_model, temp_solver);
end
fprintf('%s : renaming %s to %s \n', mfilename, temp_model, stat.best_snapshot_name);
movefile( temp_model, stat.best_snapshot_name );
fprintf('%s : renaming %s to %s \n', mfilename, temp_solver, [best_snapshot_name solver_ext]);
movefile( temp_solver, [best_snapshot_name solver_ext] );

% --- Delete non best snapshots
if clear_other_snapshots
    delete_except( [solver_params.snapshot_prefix '*' model_ext],  [best_snapshot_name_only model_ext] );
    delete_except( [solver_params.snapshot_prefix '*' solver_ext], [best_snapshot_name_only solver_ext] );
end

% --- Copy data to stat
train_num      = numel(stat.train);
iterations_num = numel(test_stat);
stat.loss_train       = zeros(1, train_num);
stat.iterations_train = zeros(1, train_num);
for iter_i = 1:train_num
    stat.loss_train(iter_i)       = stat.train{iter_i}.loss;
    stat.iterations_train(iter_i) = stat.train{iter_i}.iter;
end

stat.iterations      = zeros(1, iterations_num);
stat.loss_test       = zeros(1, iterations_num);
stat.loss_train_sync = nan(1, iterations_num);
for iter_i = 1:iterations_num
    stat.iterations(iter_i) = test_stat{iter_i}.iter;
    stat.loss_test(iter_i)  = test_stat{iter_i}.loss;
    % A test iteration may have no (or several) training entries in the
    % log; use the first match and leave NaN when there is none.
    match_i = find(stat.iterations_train == stat.iterations(iter_i), 1);
    if ~isempty(match_i)
        stat.loss_train_sync(iter_i) = stat.loss_train(match_i);
    end
end
stat.loss = stat.loss_test;
end

function ensure_dir_exists( dir_path )
% Creates dir_path (with a notice) when it is non-empty and does not exist.
if ~isempty(dir_path) && ~exist(dir_path, 'dir')
    fprintf('WARNING: %s : Dir = %s didnt exist, so it was created ...\n', ...
            mfilename, ...
            dir_path);
    mkdir(dir_path);
end
end